{-# LANGUAGE RecordWildCards, ViewPatterns, NamedFieldPuns #-}
{- HLINT ignore "Use camelCase" -}

-- | Module containing the plugin.
module RecordDotPreprocessor(plugin) where

import Data.Generics.Uniplate.Data
import Data.List.Extra
import Data.Tuple.Extra
import Compat
import Bag
import qualified GHC
import qualified GhcPlugins as GHC
import HsSyn
import SrcLoc
import TcEvidence


---------------------------------------------------------------------
-- PLUGIN WRAPPER

-- | GHC plugin.
plugin :: GHC.Plugin
plugin = GHC.defaultPlugin
    { GHC.parsedResultAction = \_cliOptions _modSummary x -> return x{GHC.hpm_module = onModule <$> GHC.hpm_module x}
    , GHC.pluginRecompile = GHC.purePlugin
    }


---------------------------------------------------------------------
-- PLUGIN GUTS

setL :: SrcSpan -> GenLocated SrcSpan e -> GenLocated SrcSpan e
setL l (L _ x) = L l x

mod_records :: GHC.ModuleName
mod_records = GHC.mkModuleName "GHC.Records.Extra"

var_HasField, var_hasField, var_getField, var_setField, var_dot :: GHC.RdrName
var_HasField = GHC.mkRdrQual mod_records $ GHC.mkClsOcc "HasField"
var_hasField = GHC.mkRdrUnqual $ GHC.mkVarOcc "hasField"
var_getField = GHC.mkRdrQual mod_records $ GHC.mkVarOcc "getField"
var_setField = GHC.mkRdrQual mod_records $ GHC.mkVarOcc "setField"
var_dot = GHC.mkRdrUnqual $ GHC.mkVarOcc "."


onModule :: HsModule GhcPs -> HsModule GhcPs
onModule x = x { hsmodImports = onImports $ hsmodImports x
               , hsmodDecls = concatMap onDecl $ hsmodDecls x
               }


onImports :: [LImportDecl GhcPs] -> [LImportDecl GhcPs]
onImports = (:) $ noL $ GHC.ImportDecl GHC.NoExt GHC.NoSourceText (noL mod_records)
    Nothing False False True {- qualified -} True {- implicit -} Nothing Nothing


{-
instance Z.HasField "name" (Company) (String) where hasField _r = (\_x -> _r{name=_x}, (name:: (Company) -> String) _r)

instance HasField "selector" Record Field where
    hasField r = (\x -> r{selector=x}, (name :: Record -> Field) r)
-}
instanceTemplate :: FieldOcc GhcPs -> HsType GhcPs -> HsType GhcPs -> InstDecl GhcPs
instanceTemplate selector record field = ClsInstD noE $ ClsInstDecl noE (HsIB noE typ) (unitBag has) [] [] [] Nothing
    where
        typ = mkHsAppTys
            (noL (HsTyVar noE GHC.NotPromoted (noL var_HasField)))
            [noL (HsTyLit noE (HsStrTy GHC.NoSourceText (GHC.occNameFS $ GHC.occName $ unLoc $ rdrNameFieldOcc selector)))
            ,noL record
            ,noL field
            ]

        has :: LHsBindLR GhcPs GhcPs
        has = noL $ FunBind noE (noL var_hasField) (mg1 eqn) WpHole []
            where
                eqn = Match
                    { m_ext     = noE
                    , m_ctxt    = FunRhs (noL var_hasField) GHC.Prefix NoSrcStrict
                    , m_pats    = compat_m_pats [VarPat noE $ noL vR]
                    , m_grhss   = GRHSs noE [noL $ GRHS noE [] $ noL $ ExplicitTuple noE [noL $ Present noE set, noL $ Present noE get] GHC.Boxed] (noL $ EmptyLocalBinds noE)
                    }
                set = noL $ HsLam noE $ mg1 Match
                    { m_ext     = noE
                    , m_ctxt    = LambdaExpr
                    , m_pats    = compat_m_pats [VarPat noE $ noL vX]
                    , m_grhss   = GRHSs noE [noL $ GRHS noE [] $ noL update] (noL $ EmptyLocalBinds noE)
                    }
                update = RecordUpd noE (noL $ GHC.HsVar noE $ noL vR)
                    [noL $ HsRecField (noL (Unambiguous noE (rdrNameFieldOcc selector))) (noL $ GHC.HsVar noE $ noL vX) False]
                get = mkApp
                    (mkParen $ mkTypeAnn (noL $ GHC.HsVar noE $ rdrNameFieldOcc selector) (noL $ HsFunTy noE (noL record) (noL field)))
                    (noL $ GHC.HsVar noE $ noL vR)

        mg1 :: Match GhcPs (LHsExpr GhcPs) -> MatchGroup GhcPs (LHsExpr GhcPs)
        mg1 x = MG noE (noL [noL x]) GHC.Generated

        vR = GHC.mkRdrUnqual $ GHC.mkVarOcc "r"
        vX = GHC.mkRdrUnqual $ GHC.mkVarOcc "x"


onDecl :: LHsDecl GhcPs -> [LHsDecl GhcPs]
onDecl o@(L _ (GHC.TyClD _ x)) = o :
    [ noL $ InstD noE $ instanceTemplate field (unLoc record) (unbang typ)
    | let fields = nubOrdOn (\(_,_,x,_) -> GHC.occNameFS $ GHC.rdrNameOcc $ unLoc $ rdrNameFieldOcc x) $ getFields x
    , (record, _, field, typ) <- fields]
onDecl x = [descendBi onExp x]

unbang :: HsType GhcPs -> HsType GhcPs
unbang (HsBangTy _ _ x) = unLoc x
unbang x = x

getFields :: TyClDecl GhcPs -> [(LHsType GhcPs, IdP GhcPs, FieldOcc GhcPs, HsType GhcPs)]
getFields DataDecl{tcdDataDefn=HsDataDefn{..}, ..} = concatMap ctor dd_cons
    where
        ctor (L _ ConDeclH98{con_args=RecCon (L _ fields),con_name=L _ name}) = concatMap (field name) fields
        ctor _ = []

        field name (L _ ConDeclField{cd_fld_type=L _ ty, ..}) = [(result, name, fld, ty) | L _ fld <- cd_fld_names]
        field _ _ = error "unknown field declaration in getFields"

        result = noL $ HsParTy noE $ foldl (\x y -> noL $ HsAppTy noE x $ hsLTyVarBndrToType y) (noL $ HsTyVar noE GHC.NotPromoted tcdLName) $ hsq_explicit tcdTyVars
getFields _ = []


-- At this point infix expressions have not had associativity/fixity applied, so they are bracketed
-- a + b + c ==> (a + b) + c
-- Therefore we need to deal with, in general:
-- x.y, where
-- x := a | a b | a.b | a + b
-- y := a | a b | a{b=1}
onExp :: LHsExpr GhcPs -> LHsExpr GhcPs
onExp (L o (OpApp _ lhs mid@(isDot -> True) rhs))
    | adjacent lhs mid, adjacent mid rhs
    , (lhsOp, lhs) <- getOpRHS $ onExp lhs
    , (lhsApp, lhs) <- getAppRHS lhs
    , (rhsApp, rhs) <- getAppLHS rhs
    , (rhsRec, rhs) <- getRec rhs
    , Just sel <- getSelector rhs
    = onExp $ lhsOp $ rhsApp $ lhsApp $ rhsRec $ mkParen $ setL o $ mkVar var_getField `mkAppType` sel `mkApp` lhs

-- Turn (.foo.bar) into getField calls
onExp (L o (SectionR _ mid@(isDot -> True) rhs))
    | adjacent mid rhs
    , srcSpanStart o == srcSpanStart (getLoc mid)
    , srcSpanEnd o == srcSpanEnd (getLoc rhs)
    , Just sels <- getSelectors rhs
    = setL o $ mkParen $ foldl1 (\x y -> noL $ OpApp noE x (mkVar var_dot) y) $ map (mkVar var_getField `mkAppType`) $ reverse sels

-- Turn a{b=c, d=e, ...} into successive setField calls
onExp (L _ RecordUpd{rupd_expr,rupd_flds=[]}) = onExp rupd_expr
onExp (L o upd@RecordUpd{rupd_expr,rupd_flds=L _ (HsRecField (fmap rdrNameAmbiguousFieldOcc -> lbl) arg pun):flds})
    | let sel = mkSelector lbl
    , let arg2 = if pun then noL $ HsVar noE lbl else arg
    , let expr = mkParen $ mkVar var_setField `mkAppType` sel `mkApp` mkParen rupd_expr `mkApp` arg2
    = onExp $ L o upd{rupd_expr=expr,rupd_flds=flds}

onExp x = descend onExp x


mkSelector :: Located GHC.RdrName -> LHsType GhcPs
mkSelector (L o x) = L o $ HsTyLit noE $ HsStrTy GHC.NoSourceText $ GHC.occNameFS $ GHC.rdrNameOcc x

getSelector :: LHsExpr GhcPs -> Maybe (LHsType GhcPs)
getSelector (L _ (HsVar _ (L o sym)))
    | not $ GHC.isQual sym
    = Just $ mkSelector $ L o sym
getSelector _ = Nothing

-- | Turn a.b.c into Just [a,b,c]
getSelectors :: LHsExpr GhcPs -> Maybe [LHsType GhcPs]
getSelectors (L _ (OpApp _ lhs mid@(isDot -> True) rhs))
    | adjacent lhs mid, adjacent mid rhs
    , Just post <- getSelector rhs
    , Just pre <- getSelectors lhs
    = Just $ pre ++ [post]
getSelectors x = (:[]) <$> getSelector x

-- | Lens on: f [x]
getAppRHS :: LHsExpr GhcPs -> (LHsExpr GhcPs -> LHsExpr GhcPs, LHsExpr GhcPs)
getAppRHS (L l (HsApp e x y)) = (L l . HsApp e x, y)
getAppRHS x = (id, x)

-- | Lens on: [f] x y z
getAppLHS :: LHsExpr GhcPs -> (LHsExpr GhcPs -> LHsExpr GhcPs, LHsExpr GhcPs)
getAppLHS (L l (HsApp e x y)) = first (\c -> L l . (\x -> HsApp e x y) . c) $ getAppLHS x
getAppLHS x = (id, x)

-- | Lens on: a + [b]
getOpRHS :: LHsExpr GhcPs -> (LHsExpr GhcPs -> LHsExpr GhcPs, LHsExpr GhcPs)
getOpRHS (L l (OpApp x y p z)) = (L l . OpApp x y p, z)
getOpRHS x = (id, x)

-- | Lens on: [r]{f1=x1}{f2=x2}
getRec :: LHsExpr GhcPs -> (LHsExpr GhcPs -> LHsExpr GhcPs, LHsExpr GhcPs)
getRec (L l r@RecordUpd{}) = first (\c x -> L l r{rupd_expr=c x}) $ getRec $ rupd_expr r
getRec x = (id, x)

-- | Is it equal to: .
isDot :: LHsExpr GhcPs -> Bool
isDot (L _ (HsVar _ (L _ op))) = op == var_dot
isDot _ = False

mkVar :: GHC.RdrName -> LHsExpr GhcPs
mkVar = noL . HsVar noE . noL

mkParen :: LHsExpr GhcPs -> LHsExpr GhcPs
mkParen = noL . HsPar noE

mkApp :: LHsExpr GhcPs -> LHsExpr GhcPs -> LHsExpr GhcPs
mkApp x y = noL $ HsApp noE x y

-- | Are the end of a and the start of b next to each other, no white space
adjacent :: Located a -> Located b -> Bool
adjacent (L a _) (L b _) = isGoodSrcSpan a && srcSpanEnd a == srcSpanStart b