module Pinchot.SyntaxTree.Optics where
import Data.Coerce (coerce)
import Data.Foldable (toList)
import Data.Maybe (catMaybes)
import Data.Sequence (Seq)
import qualified Control.Lens as Lens
import qualified Language.Haskell.TH as T
import qualified Language.Haskell.TH.Syntax as Syntax
import Pinchot.Rules
import Pinchot.Types
import Pinchot.Intervals
-- | Creates optics declarations for a sequence of rules.
--
-- The input sequence is first passed through 'families' (defined
-- elsewhere; presumably it expands the rules to include related rules --
-- confirm against "Pinchot.Rules").  'ruleToOptics' is then run on every
-- resulting rule and the declarations are concatenated in order.
rulesToOptics
  :: Syntax.Lift t
  => Qualifier
  -- ^ Qualifier handed on to 'ruleToOptics'
  -> T.Name
  -- ^ Name of the terminal type, handed on to 'ruleToOptics'
  -> Seq (Rule t)
  -- ^ Rules to process
  -> T.Q [T.Dec]
rulesToOptics qual termName rules =
  concat <$> traverse (ruleToOptics qual termName) (families rules)
-- | Creates optics declarations for a single rule.
--
-- * A 'Terminal' rule is handled by 'terminalToOptics'.
--
-- * A 'NonTerminal' rule is handled by 'nonTerminalToOptics'.
--
-- * A 'Record' rule is handled by 'recordsToOptics'.
--
-- * Every other kind of rule yields no declarations.
ruleToOptics
  :: Syntax.Lift t
  => Qualifier
  -- ^ Qualifier handed on to the per-kind generators
  -> T.Name
  -- ^ Name of the terminal type (only used for 'Terminal' rules)
  -> Rule t
  -- ^ Rule to generate optics for
  -> T.Q [T.Dec]
ruleToOptics qual termName (Rule ruleName _ ruleType) = declsFor ruleType
  where
    declsFor (Terminal ivls) =
      terminalToOptics qual termName ruleName ivls
    declsFor (NonTerminal firstBranch otherBranches) =
      sequence (nonTerminalToOptics qual ruleName firstBranch otherBranches)
    declsFor (Record fields) =
      sequence (recordsToOptics qual ruleName fields)
    declsFor _ =
      return []
-- | Builds the prism declarations for a terminal rule.
--
-- Returns two declarations for the name formed by prepending an
-- underscore to the rule name: a type signature and a value binding.
--
-- The signature is a @Prism'@ whose outer type is a @(terminal, a)@ pair
-- and whose inner type is the rule's data type applied to the terminal
-- type and @a@.
--
-- The binding is built with @Lens.prism'@ from two local functions:
-- @fetch@ unwraps the rule's constructor, and @store@ wraps a pair in
-- the constructor only when 'inIntervals' accepts the terminal for the
-- given intervals, yielding 'Nothing' otherwise.
terminalToOptics
  :: Syntax.Lift t
  => Qualifier
  -- ^ Qualifier applied (via 'quald') to the rule name
  -> T.Name
  -- ^ Name of the terminal type
  -> String
  -- ^ Rule name
  -> Intervals t
  -- ^ Intervals the terminal must fall in; lifted into the generated code
  -> T.Q [T.Dec]
terminalToOptics qual termName nm ivls = do
  e1 <- T.sigD opticName
    $ T.forallT [ T.PlainTV (T.mkName "a")] (return [])
    [t| Lens.Prism' ( $(T.conT termName), $(anyType) )
                    ($(T.conT (quald qual nm)) $(T.conT termName) $(anyType))
    |]
  e2 <- T.valD prismName (T.normalB expn) []
  return [e1, e2]
  where
    -- Name shared by the signature and the binding.
    opticName = T.mkName ('_' : nm)
    anyType = T.varT (T.mkName "a")
    prismName = T.varP opticName
    fetchPat = T.conP (quald qual nm) [T.varP (T.mkName "_x")]
    fetchName = T.varE (T.mkName "_x")
    ctor = T.conE (quald qual nm)
    expn = [| let fetch $fetchPat = $fetchName
                  store (term, a)
                    | inIntervals ivls term = Just ($(ctor) (term, a))
                    | otherwise = Nothing
              in Lens.prism' fetch store
           |]
-- | Builds prism declarations for every branch of a non-terminal rule.
--
-- For each 'Branch' two declarations are produced, in order: a type
-- signature and a value binding for a prism whose name is the branch
-- name with an underscore prepended.
--
-- The prism's outer type is the rule's data type applied to the type
-- variables @t@ and @a@.  Its inner type depends on how many rules the
-- branch holds: unit for none, the single rule's type for one, and a
-- tuple of the rules' types otherwise.
--
-- The binding applies @Lens.prism@ to a function that builds the branch
-- constructor from the inner value, and a function that returns 'Right'
-- with the fields when the constructor matches.  A catch-all returning
-- 'Left' of the unchanged value is generated only when the rule has more
-- than one branch.
nonTerminalToOptics
  :: Qualifier
  -- ^ Qualifier passed to 'quald' for every type and constructor name
  -> String
  -- ^ Name of the non-terminal rule
  -> Branch t
  -- ^ First branch
  -> Seq (Branch t)
  -- ^ Remaining branches
  -> [T.Q T.Dec]
nonTerminalToOptics qual nm b1 bsSeq = concatMap makePrism (b1 : bs)
  where
    bs = toList bsSeq

    charType, anyType :: T.TypeQ
    charType = T.varT (T.mkName "t")
    anyType = T.varT (T.mkName "a")

    -- The qualified type constructor of that name, applied to @t@ and @a@.
    appliedType :: String -> T.TypeQ
    appliedType tyName =
      [t| $(T.conT (quald qual tyName)) $(charType) $(anyType) |]

    makePrism (Branch inner rulesSeq) = [signature, binding]
      where
        rules = toList rulesSeq
        ctorName = quald qual inner
        prismName = T.mkName ('_' : inner)

        -- One numbered name per field of the branch: prefix0, prefix1, ...
        numbered :: String -> [T.Name]
        numbered prefix =
          [ T.mkName (prefix ++ show i)
          | i <- take (length rules) [(0 :: Int) ..]
          ]

        xName = T.mkName "_x"
        zName = T.mkName "_z"
        xNames = numbered "_x"
        yNames = numbered "_y"

        signature = T.sigD prismName . forallA $
          [t| Lens.Prism' $(appliedType nm) $(fieldsType) |]

        -- With no rules the fold leaves the bare 0-tuple (unit) type.
        fieldsType = case rules of
          [Rule r _ _] -> appliedType r
          rs -> foldl T.appT (T.tupleT (length rs))
                  (map (\(Rule r _ _) -> appliedType r) rs)

        binding = T.valD (T.varP prismName) (T.normalB prismExpr) []

        prismExpr = (T.varE 'Lens.prism) `T.appE` build `T.appE` match

        -- Inner value -> branch constructor.  With no rules this is a
        -- unit pattern and the bare constructor.
        build = case rules of
          [_] -> T.lamE [T.varP xName]
                   (T.conE ctorName `T.appE` T.varE xName)
          _ -> T.lamE [T.tupP (map T.varP xNames)]
                 (foldl T.appE (T.conE ctorName) (map T.varE xNames))

        -- Outer value -> Either unchanged-outer inner-value.
        match = T.lamE [T.varP xName] (T.caseE (T.varE xName) (hit : miss))
          where
            hit = T.match
              (T.conP ctorName (map T.varP yNames))
              (T.normalB (T.conE 'Right `T.appE` payload))
              []
            -- A single field is returned bare rather than as a tuple.
            payload = case map T.varE yNames of
              [only] -> only
              fields -> T.tupE fields
            miss
              | null bs = []
              | otherwise =
                  [ T.match
                      (T.varP zName)
                      (T.normalB (T.conE 'Left `T.appE` T.varE zName))
                      []
                  ]
-- | Builds lens declarations for the fields of a record rule.
--
-- Fields are numbered from zero in sequence order.  For each field two
-- declarations are produced, in order: a type signature and a function
-- binding.  The lens name comes from 'recordFieldName' applied to the
-- field's index, the record's name and the field rule's name.
--
-- The lens is built with @Lens.lens@.  Reading applies the record
-- accessor, which is the lens name with an underscore prepended and
-- qualified via 'quald'.  Writing is a record update on that same
-- accessor.
recordsToOptics
  :: Qualifier
  -- ^ Qualifier passed to 'quald' for type and accessor names
  -> String
  -- ^ Name of the record rule
  -> Seq (Rule t)
  -- ^ Rules making up the record's fields
  -> [T.Q T.Dec]
recordsToOptics qual nm fields =
  concat (zipWith makeLens [(0 :: Int) ..] (toList fields))
  where
    -- The qualified type constructor of that name, applied to @t@ and @a@.
    appliedType :: String -> T.TypeQ
    appliedType tyName =
      [t| $(T.conT (quald qual tyName))
            $(T.varT (T.mkName "t"))
            $(T.varT (T.mkName "a")) |]

    makeLens index (Rule inner _ _) =
      [ T.sigD lensName lensType
      , T.funD lensName [T.clause [] (T.normalB lensExpr) []]
      ]
      where
        fieldNm = recordFieldName index nm inner
        lensName = T.mkName fieldNm
        accessor = quald qual ('_' : fieldNm)
        recVar = T.mkName "_namedRec"
        newVar = T.mkName "_namedNewVal"

        lensType = forallA
          [t| Lens.Lens' $(appliedType nm) $(appliedType inner) |]

        lensExpr =
          (T.varE 'Lens.lens) `T.appE` readField `T.appE` writeField

        readField = T.lamE [T.varP recVar]
          (T.varE accessor `T.appE` T.varE recVar)

        writeField = T.lamE [T.varP recVar, T.varP newVar]
          (T.recUpdE (T.varE recVar) [return (accessor, T.VarE newVar)])
-- | Wraps a type in a @forall@ that binds the type variables @t@ and @a@
-- (in that order) with an empty context.
forallA :: T.TypeQ -> T.TypeQ
forallA = T.forallT binders (return [])
  where
    binders = map (T.PlainTV . T.mkName) ["t", "a"]