{-# LANGUAGE TemplateHaskellQuotes #-}
{-# LANGUAGE LambdaCase #-}
-----------------------------------------------------------------------------
-- |
-- Module      : Each.Transform
-- Copyright   : (c) dramforever 2017
-- License     : BSD3
--
-- Maintainer  : dramforever
-- Stability   : experimental
-- Portability : non-portable (Template Haskell)
--
-- An internal module where most of the real transformation goes on.
-----------------------------------------------------------------------------

module Each.Transform where

import Control.Applicative
import Control.Monad
import Data.DList (DList, singleton, toList)
import Data.Monoid
import Language.Haskell.TH

import qualified Each.Invoke

-- | A writer monad where the empty case is distinguished.
--
-- 'Impure' carries the @do@-statements that must run before the carried
-- value may be used; 'Pure' means no statements were produced, so the value
-- can simply be wrapped in 'pure' by 'generate'.
data Result a
  = Impure (DList Stmt) a
    -- ^ Invariant: the bind list is not empty
  | Pure a

instance Functor Result where
  f `fmap` Impure bs x = Impure bs (f x)
  f `fmap` Pure a = Pure (f a)

instance Applicative Result where
  pure = Pure
  Pure f <*> Pure x = Pure (f x)
  Pure f <*> Impure xbs x = Impure xbs (f x)
  Impure fbs f <*> Pure x = Impure fbs (f x)
  -- Statements of the left operand come first, so binds run left to right.
  Impure fbs f <*> Impure xbs x = Impure (fbs <> xbs) (f x)

instance Monad Result where
  Impure bs x >>= k = case k x of
    Impure ks r -> Impure (bs <> ks) r
    Pure r -> Impure bs r
  Pure x >>= k = k x

-- | Record the statement @n <- e@; the result is always 'Impure'.
addBind :: Name -> Exp -> Result ()
addBind n e = Impure (singleton (BindS (VarP n) e)) ()

-- | Invoke an 'each' block: transform the quoted expression and render the
-- result with 'generate'.
each :: ExpQ -> ExpQ
each inp = generate <$> (inp >>= transform)

-- | Render a transformation result as a single expression.
--
-- A 'Pure' result becomes @pure x@; an 'Impure' one becomes a @do@ block
-- that runs the recorded statements in order and ends with @return x@.
generate :: Result Exp -> Exp
generate (Pure x) = AppE (VarE 'Control.Applicative.pure) x
generate (Impure xs x) =
  DoE $ toList (xs <> singleton (NoBindS (AppE (VarE 'Control.Monad.return) x)))

-- | Rewrite an expression so that every invocation of 'Each.Invoke.bind'
-- (written @bind x@, @bind $ x@ or @(~! x)@) is replaced by a fresh variable,
-- with the matching @var <- x@ statement recorded in the 'Result'.
--
-- Fails in 'Q' (via 'fail') on syntax it does not support.
--
-- NOTE: statements are hoisted out of their surrounding expression. For
-- example both operands of an infix operator are transformed unconditionally,
-- so a bind on the right of @&&@ runs even when the left side is 'False'.
transform :: Exp -> Q (Result Exp)
-- Detecting and processing invocations of bind
transform (InfixE Nothing (VarE v) (Just x))
  | v == '(Each.Invoke.~!) = impurify x
transform (AppE (VarE v) x)
  | v == 'Each.Invoke.bind = impurify x
transform (InfixE (Just (VarE vf)) (VarE vo) (Just x))
  | vf == 'Each.Invoke.bind && vo == '(Prelude.$) = impurify x
transform (VarE n) = pure $ pure (VarE n)
transform (ConE n) = pure $ pure (ConE n)
transform (LitE l) = pure $ pure (LitE l)
transform (AppE f x) = liftA2 (liftA2 AppE) (transform f) (transform x)
transform (InfixE lhs mid rhs) = do
  tl <- traverse transform lhs
  tm <- transform mid
  tr <- traverse transform rhs
  pure (liftA3 InfixE (sequence tl) tm (sequence tr))
-- Explicit parentheses: only the inner expression needs transforming.
transform (ParensE e) = fmap ParensE <$> transform e
-- NOTE: binds found in the lambda body are hoisted outside the lambda, so
-- they must not mention the lambda's arguments.
-- TODO Maybe add checks to ensure that the arguments aren't used impurely?
transform (LamE ps x) = fmap (LamE ps) <$> transform x
transform (TupE ps) = fmap TupE . sequence <$> traverse transform ps
transform (CondE c t f) = do
  tc <- transform c
  tt <- transform t
  tf <- transform f
  case liftA2 (,) tt tf of
    -- Both branches pure: keep an ordinary if-then-else.
    Pure (et, ef) -> pure $ (\z -> CondE z et ef) <$> tc
    -- Some branch binds: each branch is rendered as its own block so that
    -- only the chosen branch's statements run, and the whole conditional is
    -- bound to a fresh variable.
    _ -> do
      var <- newName "bind"
      pure $ do
        ec <- tc
        addBind var (CondE ec (generate tt) (generate tf))
        pure (VarE var)
transform (MultiIfE guards) =
  case desugarMultiIf guards of
    Right x -> transform x
    Left err -> fail err
  where
    -- Rewrite boolean guards into a chain of if-then-else that ends in a
    -- call to 'error' when every guard fails.
    desugarMultiIf :: [(Guard, Exp)] -> Either String Exp
    desugarMultiIf [] =
      pure (AppE (VarE 'Prelude.error) (LitE $ StringL errNonExhaustiveGuard))
    desugarMultiIf ((NormalG c, t) : rest) = CondE c t <$> desugarMultiIf rest
    desugarMultiIf ((PatG _, _) : _) = Left errPatternGuard
transform (LetE [] e) = transform e
transform (LetE (ValD p v [] : ds) e) =
  -- Haskell 2010 Report, section 3.12: a non-recursive @let p = e1 in e0@
  -- means @case e1 of ~p -> e0@. The lazy pattern keeps the binding from
  -- forcing @e1@ before any variable of @p@ is demanded. The scrutinee is
  -- outside the scope of @p@ and of the remaining declarations, so
  -- recursive bindings are not supported here.
  transform (CaseE (bodyToExp v) [Match (lazify p) (NormalB $ LetE ds e) []])
  where
    -- Wrap a pattern in '~' unless it is already irrefutable, or is a bang
    -- pattern (whose strictness must be kept).
    lazify :: Pat -> Pat
    lazify = \case
      pat@(VarP _) -> pat
      WildP -> WildP
      pat@(TildeP _) -> pat
      pat@(BangP _) -> pat
      pat -> TildeP pat
transform (LetE (ValD _ _ _ : _) _) = fail errWhere
transform (LetE _ _) = fail errComplexLet
transform (CaseE s ma) = do
  ts <- transform s
  tm <- traverse transformMatch ma
  case traverse getPureMatch tm of
    -- Every alternative is pure: keep an ordinary case expression.
    Just pes -> pure $ (\z -> CaseE z (toMatch <$> pes)) <$> ts
    -- Some alternative binds: render each alternative as its own block and
    -- bind the whole case to a fresh variable.
    Nothing -> do
      var <- newName "bind"
      pure $ do
        es <- ts
        addBind var (CaseE es (generateMatch <$> tm))
        pure (VarE var)
  where
    generateMatch :: (Pat, Result Exp) -> Match
    generateMatch (p, e) = toMatch (p, generate e)

    toMatch :: (Pat, Exp) -> Match
    toMatch (p, e) = Match p (NormalB e) []

    getPureMatch :: (Pat, Result Exp) -> Maybe (Pat, Exp)
    getPureMatch (pat, Pure e) = Just (pat, e)
    getPureMatch _ = Nothing

    -- Alternatives with a 'where' clause are rejected.
    transformMatch :: Match -> Q (Pat, Result Exp)
    transformMatch (Match pat body []) =
      (\x -> (pat, x)) <$> transform (bodyToExp body)
    transformMatch _ = fail errWhere
transform (ArithSeqE z) = fmap ArithSeqE <$> case z of
  FromR a -> fmap FromR <$> transform a
  FromThenR a b -> liftA2 (liftA2 FromThenR) (transform a) (transform b)
  FromToR a b -> liftA2 (liftA2 FromToR) (transform a) (transform b)
  FromThenToR a b c ->
    liftA3 (liftA3 FromThenToR) (transform a) (transform b) (transform c)
transform (ListE xs) = fmap ListE . sequence <$> traverse transform xs
transform (SigE e t) = fmap (\te -> SigE te t) <$> transform e
transform (RecConE name fes) =
  fmap (RecConE name) . sequence <$> traverse transformFieldExp fes
transform (RecUpdE x fes) =
  liftA2 (liftA2 RecUpdE) (transform x) (sequence <$> traverse transformFieldExp fes)
transform (UnboundVarE n) = pure $ pure (UnboundVarE n)
transform x = fail (errUnsupported <> pprint x)

-- | Turn a binding body into an expression. Guarded bodies become a
-- multi-way if.
--
-- NOTE: when used for case alternatives, exhausting the guards therefore
-- ends in the 'error' call produced by the multi-way-if desugaring rather
-- than falling through to the next alternative.
bodyToExp :: Body -> Exp
bodyToExp (NormalB x) = x
bodyToExp (GuardedB x) = MultiIfE x

-- | Transform the expression of a record field, keeping its name.
transformFieldExp :: FieldExp -> Q (Result FieldExp)
transformFieldExp (nm, e) = fmap (\x -> (nm, x)) <$> transform e

-- | Handle one invocation of bind: transform the argument, record
-- @fresh <- argument@, and stand in the fresh variable for the invocation.
impurify :: Exp -> Q (Result Exp)
impurify e = liftA2 go (transform e) (newName "bind")
  where
    go te nm = te >>= \z -> VarE nm <$ addBind nm z

errNonExhaustiveGuard, errUnsupported, errPatternGuard, errWhere, errComplexLet :: String
errNonExhaustiveGuard = "Non-exhaustive guard"
errUnsupported = "Unsupported syntax in: "
errPatternGuard = "Pattern guards are not supported"
errWhere = "'where' is not supported"
errComplexLet = "Only declarations like 'pattern = value' are supported in let"