{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -Wno-unused-imports #-}
{-# OPTIONS_GHC -split-sections #-}

module TypedSession.TH where

import Control.Monad (forM)
import Data.Either (fromRight)
import Data.Kind
import qualified Data.Set as Set
import Language.Haskell.TH hiding (Type)
import qualified Language.Haskell.TH as TH
import Language.Haskell.TH.Quote
import TypedSession.State.GenDoc
import TypedSession.State.Parser (runProtocolParser)
import TypedSession.State.Piple (PipleResult (..), piple)
import TypedSession.State.Render
import TypedSession.State.Type (BranchSt (..), Creat, MsgOrLabel (..), MsgT1, Protocol (..), ProtocolError, T (..))

-- | Generate the singleton machinery for a role type.
--
-- The given name is reified and must denote a data type declared without a
-- context, without type parameters and without a kind signature. For such a
-- type @R@ the following declarations are produced:
--
-- * a GADT named by 'addS' (@SR@) indexed by a variable of kind @R@, with one
--   constructor @SC :: SR 'C@ for every nullary constructor @C@ of @R@;
-- * @type instance Data.IFunctor.Sing = SR@;
-- * one @Data.IFunctor.SingI 'C@ instance per nullary constructor, whose
--   @sing@ is @SC@;
-- * a @TypedSession.Core.SingToInt R@ instance whose @singToInt x@ is
--   @I# (dataToTag# x)@.
--
-- Constructors of @R@ that take arguments are skipped. The generated code
-- refers to @I#@ and @dataToTag#@ through unqualified 'mkName', so both must be
-- in scope where the declarations are spliced.
--
-- Fails in 'Q' when the reified name is not a data type of the shape above.
roleDecs :: Name -> Q [Dec]
roleDecs name = do
  res <- reify name
  case res of
    TyConI (DataD [] dName [] Nothing cons _) -> do
      a <- newName "a"
      x <- newName "x"
      let sName = addS dName
          -- Only nullary, non-record constructors take part in the singleton.
          nullaryCons = [n | NormalC n [] <- cons]
          singData =
            DataD
              []
              sName
              [KindedTV a BndrReq (ConT dName)]
              Nothing
              [GadtC [addS n] [] (AppT (ConT sName) (PromotedT n)) | n <- nullaryCons]
              []
          singFamilyInst =
            TySynInstD (TySynEqn Nothing (ConT (mkName "Data.IFunctor.Sing")) (ConT sName))
          singIInsts =
            [ InstanceD
                Nothing
                []
                (AppT (ConT $ mkName "Data.IFunctor.SingI") (PromotedT n))
                [FunD (mkName "sing") [Clause [] (NormalB (ConE (addS n))) []]]
            | n <- nullaryCons
            ]
          singToIntInst =
            InstanceD
              Nothing
              []
              (AppT (ConT (mkName "TypedSession.Core.SingToInt")) (ConT name))
              [ FunD
                  (mkName "singToInt")
                  [Clause [VarP x] (NormalB (AppE (ConE (mkName "I#")) (AppE (VarE $ mkName "dataToTag#") (VarE x)))) []]
              ]
      pure $ [singData, singFamilyInst] ++ singIInsts ++ [singToIntInst]
    _ ->
      fail $
        "roleDecs: expected "
          <> show name
          <> " to be a data type without context, type parameters or kind signature"

-- | Build the name of the singleton counterpart of a type or constructor:
-- the unqualified base of the given name, prefixed with @S@.
--
-- Any module qualification of the input is dropped, because only 'nameBase'
-- is used and the result is created with 'mkName'.
addS :: Name -> Name
addS name = mkName ("S" <> nameBase name)

-- | Generate the protocol state type, its singleton and the
-- @TypedSession.Core.Protocol@ instance for a checked protocol.
--
-- Arguments:
--
-- * @protN@: name of the state data type to generate; its singleton type is
--   named @S@ followed by @protN@.
-- * @roleName@: name of the role data type; it is reified to enumerate its
--   nullary constructors for the @Done@ instances.
-- * @bstName@: name of the type used as the argument of states listed in
--   @dnySet@.
-- * the 'PipleResult' supplying the message tree (@msgT1@), the set of state
--   numbers that carry a @bstName@ argument (@dnySet@) and the inclusive range
--   of state numbers (@stBound@).
--
-- Produced declarations, in order:
--
-- 1. @data protN@ with one constructor per state number in the range: @End@
--    for @-1@, otherwise @S\<i\>@ (with a @bstName@ field when @i@ is in
--    @dnySet@);
-- 2. the singleton GADT @SprotN@ with constructors @SEnd@ / @SS\<i\>@;
-- 3. @type instance Data.IFunctor.Sing = SprotN@;
-- 4. one @Data.IFunctor.SingI@ instance per state number;
-- 5. a @TypedSession.Core.SingToInt protN@ instance;
-- 6. a @TypedSession.Core.Protocol roleName protN@ instance containing the
--    @Done@ type instances and the @Msg@ data instance whose constructors are
--    collected from every @Msg@ node of the protocol tree.
--
-- Names such as @Msg@, @Done@, @End@, @I#@ and @dataToTag#@ are created with
-- unqualified 'mkName' and are therefore resolved where the declarations are
-- spliced.
--
-- Fails in 'Q' when @roleName@ is not a data type without context, type
-- parameters and kind signature, or when a message argument type is blank.
protDecsAndMsgDecs :: forall r bst. (Show r, Show bst) => String -> Name -> Name -> PipleResult r bst -> Q [Dec]
protDecsAndMsgDecs protN roleName bstName PipleResult{msgT1, dnySet, stBound = (fromVal, toVal)} = do
  sVar <- newName "s"
  let protName = mkName protN
      protSName = mkName ("S" <> protN)

      -- Field annotation used for every generated constructor field.
      noBang :: Bang
      noBang = Bang NoSourceUnpackedness NoSourceStrictness

      -- Constructor name of state number i in the state type.
      stName :: Int -> Name
      stName i = mkName ("S" <> show i)

      -- Constructor name of state number i in the singleton type.
      sstName :: Int -> Name
      sstName i = mkName ("SS" <> show i)

      -- Kind of the sender/receiver indices of Msg: (roleName, protName).
      rolePairKind :: TH.Type
      rolePairKind = AppT (AppT (TupleT 2) (ConT roleName)) (ConT protName)

      -- Constructor of the state type for state number i.
      genConstr :: Int -> Con
      genConstr i
        | i == -1 = NormalC (mkName "End") []
        | i `Set.member` dnySet = NormalC (stName i) [(noBang, ConT bstName)]
        | otherwise = NormalC (stName i) []

      -- Constructor of the singleton GADT for state number i.
      genSConstr :: Int -> Con
      genSConstr i
        | i == -1 =
            GadtC [mkName "SEnd"] [] (AppT (ConT protSName) (PromotedT (mkName "End")))
        | i `Set.member` dnySet =
            ForallC
              [KindedTV sVar SpecifiedSpec (ConT bstName)]
              []
              ( GadtC
                  [sstName i]
                  []
                  (AppT (ConT protSName) (AppT (PromotedT (stName i)) (VarT sVar)))
              )
        | otherwise =
            GadtC [sstName i] [] (AppT (ConT protSName) (PromotedT (stName i)))

      isTAny :: T bst -> Bool
      isTAny = \case
        TAny _ -> True
        _ -> False

      -- Walk the protocol tree and emit one GADT constructor of the Msg data
      -- instance per Msg node; s is the type variable standing for the
      -- argument of TAny states.
      mkInstanceMsg :: Name -> Protocol (MsgT1 r bst) r bst -> Q [Con]
      mkInstanceMsg s = \case
        Msg ((a, b, c), (from, to), _) constr args _ _ :> prots -> do
          let tAnyToType :: T bst -> TH.Type
              tAnyToType = \case
                TNum i -> PromotedT (stName i)
                BstList i bst -> AppT (PromotedT (stName i)) (PromotedT (mkName (show bst)))
                TAny i -> AppT (PromotedT (stName i)) (VarT s)
                TEnd -> PromotedT $ mkName "End"

              -- Promoted pair '(role, state) used for the send/recv indices.
              promotedPair :: r -> T bst -> TH.Type
              promotedPair role st =
                AppT (AppT (PromotedTupleT 2) (PromotedT (mkName (show role)))) (tAnyToType st)

              -- Result type: Msg roleName protName a '(from, b) '(to, c).
              msgTy :: TH.Type
              msgTy =
                foldl'
                  AppT
                  (ConT $ mkName "Msg")
                  [ ConT roleName
                  , ConT protName
                  , tAnyToType a
                  , promotedPair from b
                  , promotedPair to c
                  ]

          -- Each argument is a whitespace-separated type application, e.g.
          -- "Maybe Int"; every word is treated as a type constructor name.
          fields <- forM args $ \ag -> case words ag of
            [] ->
              fail $
                "protDecsAndMsgDecs: blank argument type in message constructor " <> constr
            (x : xs) -> pure (noBang, foldl' AppT (ConT (mkName x)) (map (ConT . mkName) xs))

          let gadtc = GadtC [mkName constr] fields msgTy
              -- Quantify over s only when one of the states mentions it.
              val =
                if any isTAny [a, b, c]
                  then ForallC [KindedTV s SpecifiedSpec (ConT bstName)] [] gadtc
                  else gadtc
          res <- mkInstanceMsg s prots
          pure (val : res)
        Label _ _ :> prots -> mkInstanceMsg s prots
        Branch _ _ ls -> do
          ls' <- forM ls $ \(BranchSt _ _ prot) -> mkInstanceMsg s prot
          pure $ concat ls'
        _ -> pure []

  a <- newName "a"
  s1 <- newName "s1"
  x <- newName "x"

  -- make instance Done
  res <- reify roleName
  instDoneDesc <- case res of
    TyConI (DataD [] _ [] Nothing cons _) ->
      pure
        [ TySynInstD
            ( TySynEqn
                Nothing
                (AppT (ConT (mkName "Done")) (PromotedT n))
                (PromotedT (mkName "End"))
            )
        | NormalC n [] <- cons
        ]
    _ ->
      fail $
        "protDecsAndMsgDecs: expected "
          <> show roleName
          <> " to be a data type without context, type parameters or kind signature"

  -- make instance msg
  ss <- newName "s"
  instMsgDesc <- mkInstanceMsg ss msgT1

  fromVar <- newName "from"
  sendVar <- newName "send"
  recvVar <- newName "recv"

  let stateIxs = [fromVal .. toVal]

      protDataDec = DataD [] protName [] Nothing (map genConstr stateIxs) []

      singDataDec =
        DataD [] protSName [KindedTV a BndrReq (ConT protName)] Nothing (map genSConstr stateIxs) []

      singFamilyInst =
        TySynInstD (TySynEqn Nothing (ConT (mkName "Data.IFunctor.Sing")) (ConT protSName))

      -- Type argument of the SingI instance for state number i.
      singIHead :: Int -> TH.Type
      singIHead i
        | i == -1 = PromotedT (mkName "End")
        | i `Set.member` dnySet = SigT (AppT (PromotedT (stName i)) (VarT s1)) (ConT protName)
        | otherwise = PromotedT (stName i)

      singIInsts =
        [ InstanceD
            Nothing
            []
            (AppT (ConT $ mkName "Data.IFunctor.SingI") (singIHead i))
            [ FunD
                (mkName "sing")
                [Clause [] (NormalB (ConE (mkName $ "S" <> (if i == -1 then "End" else ("S" <> show i))))) []]
            ]
        | i <- stateIxs
        ]

      singToIntInst =
        InstanceD
          Nothing
          []
          (AppT (ConT (mkName "TypedSession.Core.SingToInt")) (ConT protName))
          [ FunD
              (mkName "singToInt")
              [Clause [VarP x] (NormalB (AppE (ConE (mkName "I#")) (AppE (VarE $ mkName "dataToTag#") (VarE x)))) []]
          ]

      msgDataInst =
        DataInstD
          []
          ( Just
              [ KindedTV fromVar () (ConT protName)
              , KindedTV sendVar () rolePairKind
              , KindedTV recvVar () rolePairKind
              ]
          )
          ( foldl'
              AppT
              (ConT $ mkName "Msg")
              [ ConT roleName
              , ConT protName
              , SigT (VarT fromVar) (ConT protName)
              , SigT (VarT sendVar) rolePairKind
              , SigT (VarT recvVar) rolePairKind
              ]
          )
          Nothing
          instMsgDesc
          []

      protocolInst =
        InstanceD
          Nothing
          []
          ( AppT
              (AppT (ConT (mkName "TypedSession.Core.Protocol")) (ConT roleName))
              (ConT protName)
          )
          (instDoneDesc ++ [msgDataInst])

  pure $
    [protDataDec, singDataDec, singFamilyInst]
      ++ singIInsts
      ++ [singToIntInst, protocolInst]

-- | Quasi-quoter that turns a textual protocol description into declarations.
--
-- The role type @r@ and branch-state type @bst@ must be supplied with type
-- applications; @protN@ names the protocol state type to generate, and
-- @roleName@ / @bstName@ are the Template Haskell names of the role and
-- branch-state types.
--
-- Only the declaration context is supported; using the quoter as an
-- expression, pattern or type fails.
--
-- In a declaration context the quoted text is parsed with
-- 'runProtocolParser' and processed with 'piple'; a failure of either step is
-- reported through 'fail'. On success the rendered graph ('genGraph' with
-- 'defaultStrFilEnv') is written to the file @\<protN\>.prot@ and printed to
-- standard output, and the result of 'roleDecs' followed by
-- 'protDecsAndMsgDecs' is returned.
protocol
  :: forall r bst
   . ( Enum r
     , Bounded r
     , Show r
     , Enum bst
     , Bounded bst
     , Show bst
     , Ord r
     )
  => String -> Name -> Name -> QuasiQuoter
protocol protN roleName bstName =
  QuasiQuoter
    { quoteExp = const $ fail "No protocol parse for exp"
    , quotePat = const $ fail "No protocol parse for pat"
    , quoteType = const $ fail "No protocol parser for type"
    , quoteDec = parseOrThrow
    }
 where
  parseOrThrow :: String -> Q [Dec]
  parseOrThrow st = case runProtocolParser @r @bst st of
    -- The parser error is already a String; pass it through unchanged so its
    -- line breaks are not escaped by 'show'.
    Left e -> fail e
    Right protCreat -> case piple protCreat of
      Left e -> fail (show e)
      Right pipResult -> do
        let graphStr = genGraph @r @bst defaultStrFilEnv pipResult
        -- Side effect at compile time: dump the rendered protocol graph.
        runIO $ do
          writeFile (protN <> ".prot") graphStr
          putStrLn graphStr
        d1 <- roleDecs roleName
        d2 <- protDecsAndMsgDecs protN roleName bstName pipResult
        pure (d1 ++ d2)