{-# LANGUAGE TemplateHaskell #-} module Sums.Internal where import Data.Foldable (foldl') import qualified Language.Haskell.TH as T import qualified Control.Lens as Lens -- | Given the number of constructors, create a new data type. -- Derives Eq, Ord, and Show, except for S0, which has no derives. sumDeclaration :: Int -> T.Dec sumDeclaration i = T.DataD cxt name types Nothing ctors derives where cxt = [] name = T.mkName ("S" ++ show i) types = map mkType [1..i] where mkType n = T.PlainTV (T.mkName ("t" ++ show n)) ctors = map mkCtor [1..i] where mkCtor n = T.NormalC (T.mkName ("S" ++ show i ++ "_" ++ show n)) [ ( T.Bang T.NoSourceUnpackedness T.NoSourceStrictness , T.VarT (T.mkName ("t" ++ show n)) ) ] derives | i == 0 = [] | otherwise = map T.ConT [''Eq, ''Ord, ''Show] -- | Given the number of constructors and this particular constructor -- number, return a type signature for a prism. prismSig :: Int -> Int -> T.DecQ prismSig nCtors thisCtor = T.sigD name ty where name = T.mkName ("_S" ++ show nCtors ++ "_" ++ show thisCtor) ty = [t| Lens.Prism $(big) $(big') $(little) $(little') |] var = T.varT . T.mkName little = var ("t" ++ show thisCtor) little' = var ("t" ++ show thisCtor ++ "'") mkTypes maker = foldl' T.appT start tys where start = T.conT (T.mkName ("S" ++ show nCtors)) tys = map maker [1..nCtors] big = mkTypes (\n -> var ("t" ++ show n)) big' = mkTypes f where f n | n /= thisCtor = var ("t" ++ show n) | otherwise = var ("t" ++ show n ++ "'") -- | Given the number of constructors and the particular constructor -- number, return the prism itself. prismDecl :: Int -> Int -> T.DecQ prismDecl nCtors thisCtor = T.valD prismPat prismBody [] where otherCtorName n = T.mkName ("S" ++ show nCtors ++ "_" ++ show n) thisCtorName = otherCtorName thisCtor prismPat = T.varP (T.mkName ("_S" ++ show nCtors ++ "_" ++ show thisCtor)) prismBody = T.normalB expn where expn = [| Lens.prism $(make) $(decon) |] where make = T.conE thisCtorName decon = do x <- T.newName "x" let caseExpn = T.caseE (T.varE x) matches T.lam1E (T.varP x) caseExpn matches = found : notFounds found = do x <- T.newName "x" let pat = T.conP thisCtorName [T.varP x] body = T.normalB [| Right $(T.varE x) |] T.match pat body [] notFounds = map mkNotFound . filter (/= thisCtor) $ [1..nCtors] mkNotFound i = do x <- T.newName "x" let pat = T.conP (otherCtorName i) [T.varP x] body = T.normalB [| Left $ $(T.conE (otherCtorName i)) $(T.varE x) |] T.match pat body [] -- | Given the number of ctors and this ctor number, return a -- signature and the prism itself. prismSigAndDecl :: Int -> Int -> T.DecsQ prismSigAndDecl nCtors thisCtor = sequence [ prismSig nCtors thisCtor , prismDecl nCtors thisCtor ] -- | Given the number of ctors, return all prisms. prismsForSingleType :: Int -> T.DecsQ prismsForSingleType i = fmap concat . traverse (prismSigAndDecl i) $ [1..i]