{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE CPP             #-}
{-# LANGUAGE ViewPatterns    #-}
{-# LANGUAGE BangPatterns    #-}
module GHC.Plugin.LensThRewrite ( plugin, rewriteModule ) where
import Control.Arrow
import qualified Data.Char as Char
import Data.Function                           (on)
import Data.List

import Control.Lens

import CoreSyn
import GhcPlugins
import HsDecls
import HsDumpAst
import HsExtension
import HsSyn
import OccName
import RdrName
import TcEvidence
import Var
-- | Plugin entry point.
--
-- Installs 'rewriteMakeLenses' as the action run on each parsed module and
-- sets the recompilation behaviour to 'purePlugin'.
plugin :: Plugin
plugin = defaultPlugin
  { pluginRecompile    = purePlugin
  , parsedResultAction = \_options _summary -> rewriteMakeLenses
  }
-- | Rewrite every top-level declaration of a parsed module with
-- 'modifyDecls', splicing the resulting declaration lists back in place.
--
-- The unmodified module is passed to 'modifyDecls' as context for each
-- declaration.
rewriteModule :: HsModule GhcPs -> HsModule GhcPs
rewriteModule module' = over decls expand module'
  where
    -- Each declaration may be replaced by zero or more declarations.
    expand = concatMap (modifyDecls module')
-- | Parsed-result action: apply 'rewriteModule' to the module carried by the
-- 'HsParsedModule', leaving its source location and the other fields of the
-- parsed result untouched.
rewriteMakeLenses :: HsParsedModule -> Hsc HsParsedModule
rewriteMakeLenses parsed =
  pure (parsed & hsMod . located %~ rewriteModule)
-- | Lens onto the located module stored in the 'hpm_module' field of an
-- 'HsParsedModule'.
parsedModule :: Lens' HsParsedModule (Located (HsModule GhcPs))
parsedModule = lens getModule setModule
  where
    getModule = hpm_module
    setModule parsed newModule = parsed { hpm_module = newModule }
-- | Lens from a parsed module straight to its list of top-level
-- declarations: 'hsMod', then 'located', then 'decls'.
moduleDecls :: Lens' HsParsedModule [LHsDecl GhcPs]
moduleDecls f = hsMod (located (decls f))
-- | Lens onto the 'hpm_module' field of an 'HsParsedModule'.
hsMod :: Lens' HsParsedModule (Located (HsModule GhcPs))
hsMod f parsed =
  fmap (\newModule -> parsed { hpm_module = newModule }) (f (hpm_module parsed))
-- | Lens onto the payload of a 'Located' value; the source span is kept
-- unchanged when the payload is replaced.
located :: Lens' (Located a) a
located f (L loc payload) = fmap (L loc) (f payload)
-- | Lens onto the 'hsmodDecls' field (the top-level declarations) of a
-- module.
decls :: Lens' (HsModule a) [LHsDecl a]
decls f m = fmap (\newDecls -> m { hsmodDecls = newDecls }) (f (hsmodDecls m))
-- | Rewrite a single top-level declaration.
--
-- * A declaration that is not a @makeLenses@ splice (see
--   'isMakeLensesThSplice') is returned unchanged.
-- * A @makeLenses ''T@ splice whose type @T@ is found by 'getDecls' in the
--   given module is replaced by the declarations produced by 'genLensCall'
--   for each of its fields. The generated declarations carry no source
--   location.
-- * If nothing could be generated (the type is not declared in this module,
--   or the splice argument has a shape 'getMakeLensesSplices' does not
--   recognise), the original splice is kept rather than silently deleted.
modifyDecls :: HsModule GhcPs -> LHsDecl GhcPs -> [LHsDecl GhcPs]
modifyDecls m original@(L _ decl)
  | isMakeLensesThSplice decl
  , not (null generated) = emptyL <$> generated
  | otherwise = [original]
  where
    -- Record types declared in this module, with their fields.
    localTypes = getDecls m

    -- Lens declarations for every field of every type named by the splice.
    generated :: [HsDecl GhcPs]
    generated = concat
      [ genLensCall typeName field
      | typeName <- getMakeLensesSplices decl
      , Just fields <- [lookup typeName localTypes]
      , field <- fields
      ]
-- | True when the declaration is an untyped top-level splice whose
-- expression is an application with the unqualified variable @makeLenses@
-- in function position. The argument of the application is not inspected.
isMakeLensesThSplice :: HsDecl GhcPs -> Bool
isMakeLensesThSplice decl
  | SpliceD _ (SpliceDecl _ (L _ splice) _) <- decl
  , HsUntypedSplice _ _ _ (L _ (HsApp _ (L _ fn) (L _ _))) <- splice
  , HsVar NoExt (L _ (Unqual (occNameString -> "makeLenses"))) <- fn
  = True
  | otherwise = False
-- | Extract the type name from a @makeLenses ''T@ splice.
--
-- Returns a singleton list holding the unqualified name @T@ when the
-- declaration is an untyped splice applying the unqualified variable
-- @makeLenses@ to a name-quote bracket of an unqualified name; returns the
-- empty list for every other declaration.
getMakeLensesSplices :: HsDecl GhcPs -> [String]
getMakeLensesSplices decl
  | SpliceD _ (SpliceDecl _ (L _ splice) _) <- decl
  , HsUntypedSplice _ _ _ (L _ (HsApp _ (L _ fn) (L _ arg))) <- splice
  , HsVar NoExt (L _ (Unqual (occNameString -> "makeLenses"))) <- fn
  , HsBracket NoExt (VarBr NoExt False (Unqual (occNameString -> typ))) <- arg
  = [typ]
  | otherwise = []
-- | Variable-reference expression for the given unqualified name, built
-- with 'mkName'.
mkVar :: String -> HsExpr GhcPs
mkVar = HsVar NoExt . mkName
-- | Name of a record field, as a plain string.
type FieldName = String
-- | Name of a type, as a plain string.
type TypeName = String
-- | Build the type signature declaration
-- @name :: \<lens\> \<record\> \<focus\>@.
--
-- The head of the type application is 'tyVarLens'; the second argument
-- names the record type (turned into a type with 'tyVar') and the third is
-- the already-built type of the focused field.
genSigD
  :: FieldName     -- ^ name the signature is attached to
  -> TypeName      -- ^ name of the record type
  -> HsType GhcPs  -- ^ type of the focused field
  -> HsDecl GhcPs
genSigD lensName recordTypeName focusType =
  SigD NoExt (TypeSig NoExt [mkName lensName] wrappedType)
  where
    -- Left-nested application: (lens record) focus.
    lensType = (tyVarLens `appTy` tyVar recordTypeName) `appTy` focusType
    -- Wrappers required by 'TypeSig' around the signature type.
    wrappedType = HsWC NoExt (HsIB NoExt (emptyL lensType))
-- | Type application of two types, each wrapped with an empty source
-- location.
appTy :: HsType GhcPs -> HsType GhcPs -> HsType GhcPs
appTy fun arg = HsAppTy NoExt (emptyL fun) (emptyL arg)
-- | The unqualified type constructor @Lens'@.
tyVarLens :: HsType GhcPs
tyVarLens = HsTyVar NoExt NotPromoted (mkTyCName "Lens'")
-- | Non-promoted type whose name is built with 'mkTyVarName'.
tyVar :: String -> HsType GhcPs
tyVar = HsTyVar NoExt NotPromoted . mkTyVarName
-- | Non-promoted type whose name is built with 'mkTyCName'.
tyC :: String -> HsType GhcPs
tyC = HsTyVar NoExt NotPromoted . mkTyCName
-- | Association list from type name to fields for the type/class
-- declarations of a module.
--
-- Every 'TyClD' declaration contributes one entry, pairing the result of
-- 'getDeclTypeName' with the result of 'getFieldAndTypeName'; all other
-- declarations are skipped.
getDecls :: HsModule GhcPs -> [(TypeName, [(FieldName, HsType GhcPs)])]
getDecls m =
  [ (getDeclTypeName tyDecl, getFieldAndTypeName tyDecl)
  | L _ (TyClD NoExt tyDecl) <- hsmodDecls m
  ]
-- | Unqualified name in the variable namespace, with an empty source
-- location.
mkName :: String -> Located RdrName
mkName s = emptyL (mkRdrUnqual (mkOccName OccName.varName s))
-- | Unqualified name with an empty source location.
--
-- Despite its name, this uses the type-constructor namespace
-- ('OccName.tcName'), exactly like 'mkTyCName'.
mkTyVarName :: String -> Located RdrName
mkTyVarName s = emptyL (mkRdrUnqual (mkOccName OccName.tcName s))
-- | Unqualified name in the type-constructor namespace, with an empty
-- source location.
mkTyCName :: String -> Located RdrName
mkTyCName s = emptyL (mkRdrUnqual (mkOccName OccName.tcName s))
-- | Name of the type introduced by a data declaration.
--
-- Returns the occurrence name of the declared type for a 'DataDecl' and the
-- empty string for every other kind of type/class declaration. The function
-- is total: the name is read with 'rdrNameOcc', so it does not fail on
-- names that are not plain 'Unqual' names.
getDeclTypeName :: TyClDecl GhcPs -> String
getDeclTypeName DataDecl { tcdLName = L _ name } =
  occNameString (rdrNameOcc name)
getDeclTypeName _ = mempty
-- | Record fields of a data declaration, paired with their declared types.
--
-- Only record-syntax 'ConDeclH98' constructors contribute fields;
-- constructors of any other shape (positional, infix, GADT-style) are
-- skipped instead of causing a pattern-match failure. When several fields
-- share one type annotation (@{ a, b :: T }@), each of them is returned
-- with that type. Fields whose name is not a plain unqualified name are
-- ignored. Declarations other than 'DataDecl' yield the empty list.
--
-- The field type is returned exactly as written in the declaration.
getFieldAndTypeName :: TyClDecl GhcPs -> [(String, HsType GhcPs)]
getFieldAndTypeName DataDecl { tcdDataDefn = defn } =
  -- A failed pattern match in a generator simply skips that element, which
  -- keeps this traversal total.
  [ (occNameString fieldName, fieldType)
  | L _ ConDeclH98 { con_args = RecCon (L _ fields) } <- dd_cons defn
  , L _ ConDeclField { cd_fld_names = names, cd_fld_type = L _ fieldType } <- fields
  , L _ FieldOcc { rdrNameFieldOcc = L _ (Unqual fieldName) } <- names
  ]
getFieldAndTypeName _ = []
-- | Generate the lens for one record field.
--
-- For a field @_fooBar@ of record type @T@ this produces a signature built
-- by 'genSigD' and the binding
--
-- > fooBar = lens _fooBar (\r f -> r { _fooBar = f })
--
-- Following the naming rule of @makeLenses@, only fields whose name starts
-- with an underscore followed by at least one character get a lens; the
-- lens name is the field name without the underscore and with its first
-- character lowercased. Any other field yields no declarations.
--
-- The first argument is the name of the record type; the second is the
-- field name together with the field's type.
genLensCall
  :: String
  -> (String, HsType GhcPs)
  -> [HsDecl GhcPs]
genLensCall lensInnerType (fieldName, fieldType) =
  case fieldName of
    '_' : c : cs ->
      let lensName = Char.toLower c : cs
      in  [ genSigD lensName lensInnerType fieldType
          , valD (funBind lensName (lensEquation lensName))
          ]
    -- No leading underscore: no lens for this field.
    _ -> []
  where
    -- Single equation: <lensName> = lens <field> <setter>
    lensEquation lensName =
      matchGroup
        [ match (funRhs lensName) $ grhss
            [ grhs ((hsVar "lens" `hsApp` hsVar fieldName) `hsApp` setter) ]
        ]

    -- (\r f -> r { <field> = f })
    setter =
      hsPar $ hsLam $ matchGroup
        [ lambdaMatch setterPats $ grhss
            [ grhs $ recordUpd (hsVar "r")
                [ hsRecUpdField fieldName (hsVar "f") ]
            ]
        ]

    -- The representation of located patterns differs between the two
    -- supported compiler generations, selected here via the base version.
#if MIN_VERSION_base (4,13,0)
    setterPats = [ varPat "r", varPat "f" ]
#else
    setterPats = [ emptyL (varPat "r"), emptyL (varPat "f") ]
#endif
-- | Wrap a binding as a top-level value declaration.
valD :: HsBind GhcPs -> HsDecl GhcPs
valD bind = ValD NoExt bind
-- | Function binding with the given name and match group, an identity
-- wrapper ('WpHole') and no ticks.
funBind :: String -> MatchGroup GhcPs (LHsExpr GhcPs) -> HsBind GhcPs
funBind name matches = FunBind
  { fun_ext     = NoExt
  , fun_id      = mkName name
  , fun_matches = matches
  , fun_co_fn   = WpHole
  , fun_tick    = []
  }
-- | Match group over the given alternatives, marked 'FromSource'; every
-- alternative and the list itself get an empty source location.
matchGroup :: [Match GhcPs (LHsExpr GhcPs)] -> MatchGroup GhcPs (LHsExpr GhcPs)
matchGroup alts = MG
  { mg_ext    = NoExt
  , mg_alts   = emptyL (map emptyL alts)
  , mg_origin = FromSource
  }
-- | A match with no argument patterns, in the given context and with the
-- given right-hand sides.
match
  :: HsMatchContext (NameOrRdrName (IdP GhcPs))
  -> GRHSs GhcPs (LHsExpr GhcPs)
  -> Match GhcPs (LHsExpr GhcPs)
match ctxt rhs = Match
  { m_ext   = NoExt
  , m_ctxt  = ctxt
  , m_pats  = []
  , m_grhss = rhs
  }
-- | A lambda-expression match binding the given patterns, with the given
-- right-hand sides.
lambdaMatch
  :: [LPat GhcPs]
  -> GRHSs GhcPs (LHsExpr GhcPs)
  -> Match GhcPs (LHsExpr GhcPs)
lambdaMatch pats rhs = Match
  { m_ext   = NoExt
  , m_ctxt  = LambdaExpr
  , m_pats  = pats
  , m_grhss = rhs
  }
-- | Match context for the right-hand side of a prefix function definition
-- with the given name and no strictness annotation.
funRhs :: String -> HsMatchContext (NameOrRdrName (IdP GhcPs))
funRhs name = FunRhs
  { mc_fun        = mkName name
  , mc_fixity     = Prefix
  , mc_strictness = NoSrcStrict
  }
-- | Attach the empty source span ('noSrcSpan') to a value.
emptyL :: e -> GenLocated SrcSpan e
emptyL x = L noSrcSpan x
-- | Group guarded right-hand sides, with no local bindings; each
-- alternative gets an empty source location.
grhss :: [GRHS GhcPs (LHsExpr GhcPs)] -> GRHSs GhcPs (LHsExpr GhcPs)
grhss alts = GRHSs
  { grhssExt        = NoExt
  , grhssGRHSs      = map emptyL alts
  , grhssLocalBinds = emptyL (EmptyLocalBinds NoExt)
  }
-- | An unguarded right-hand side consisting of the given expression.
grhs :: HsExpr GhcPs -> GRHS GhcPs (LHsExpr GhcPs)
grhs body = GRHS NoExt [] (emptyL body)
-- | Function application of two expressions, each wrapped with an empty
-- source location.
hsApp :: HsExpr GhcPs -> HsExpr GhcPs -> HsExpr GhcPs
hsApp = HsApp NoExt `on` emptyL
-- | Variable-reference expression for the given unqualified name.
hsVar :: String -> HsExpr GhcPs
hsVar name = HsVar NoExt (mkName name)
-- | Parenthesise an expression.
hsPar :: HsExpr GhcPs -> HsExpr GhcPs
hsPar inner = HsPar NoExt (emptyL inner)
-- | Lambda expression over the given match group.
hsLam :: MatchGroup GhcPs (LHsExpr GhcPs) -> HsExpr GhcPs
hsLam matches = HsLam NoExt matches
-- | Record-update expression: the given record expression updated with the
-- given fields, each wrapped with an empty source location.
recordUpd :: HsExpr GhcPs -> [HsRecUpdField GhcPs] -> HsExpr GhcPs
recordUpd record updates = RecordUpd
  { rupd_ext  = NoExt
  , rupd_expr = emptyL record
  , rupd_flds = map emptyL updates
  }
-- | A single @field = expr@ entry of a record update, without punning.
-- The field label is an 'Unambiguous' occurrence of the given name.
hsRecUpdField
  :: String
  -> HsExpr GhcPs
  -> HsRecUpdField GhcPs
hsRecUpdField fieldName value = HsRecField
  { hsRecFieldLbl = emptyL label
  , hsRecFieldArg = emptyL value
  , hsRecPun      = False
  }
  where
    label :: AmbiguousFieldOcc GhcPs
    label = Unambiguous NoExt (mkName fieldName)
-- | Variable pattern binding the given unqualified name.
varPat :: String -> Pat GhcPs
varPat name = VarPat NoExt (mkName name)