module Stg.Language (
Program (..),
Binds (..),
LambdaForm (..),
prettyLambda,
UpdateFlag (..),
Rec (..),
Expr (..),
Alts (..),
NonDefaultAlts (..),
AlgebraicAlt (..),
PrimitiveAlt (..),
DefaultAlt (..),
Literal (..),
PrimOp (..),
Var (..),
Atom (..),
Constr (..),
Pretty (..),
classify,
LambdaType(..),
) where
import Control.DeepSeq
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as NonEmpty
import Data.Map (Map)
import qualified Data.Map as M
import qualified Data.Semigroup as Semigroup
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Prettyprint.Doc
import GHC.Exts
import GHC.Generics
import Language.Haskell.TH.Lift
import Stg.Language.Prettyprint
-- | A complete STG program, represented by its top-level bindings.
newtype Program = Program Binds
    deriving (Generic, Show, Ord, Eq)

-- | The empty program has no bindings; 'mappend' defers to the
-- 'Semigroup' instance.
instance Monoid Program where
    mempty  = Program (Binds mempty)
    mappend = (Semigroup.<>)

-- | Two programs combine by combining their binding groups.
instance Semigroup.Semigroup Program where
    (<>) (Program lhs) (Program rhs) = Program (lhs <> rhs)
-- | A group of bindings, keyed by the variable each binding defines.
newtype Binds = Binds (Map Var LambdaForm)
    deriving (Generic, Ord, Eq)

-- | The empty binding group; 'mappend' defers to the 'Semigroup' instance.
instance Monoid Binds where
    mempty  = Binds M.empty
    mappend = (Semigroup.<>)

-- | Binding groups combine with 'M.union' semantics: when both sides bind
-- the same variable, the binding from the left operand wins.
instance Semigroup.Semigroup Binds where
    (<>) (Binds lhs) (Binds rhs) = Binds (M.union lhs rhs)

-- | Shows the bindings as their ascending association list, always wrapped
-- in parentheses, e.g. @(Binds [(Var "x",...)])@. The precedence argument
-- is ignored, so the parentheses appear even at top level.
instance Show Binds where
    showsPrec _ (Binds bindMap) =
        showString "(Binds " . shows (M.assocs bindMap) . showChar ')'
-- | A lambda form: the right-hand side of every binding in a 'Binds' group.
-- Rendered by 'prettyLambda'.
data LambdaForm = LambdaForm
    ![Var]       -- ^ Free variables
    !UpdateFlag  -- ^ Update flag (rendered as @=>@ or @->@)
    ![Var]       -- ^ Bound variables, i.e. the arguments
    !Expr        -- ^ Body
    deriving (Eq, Ord, Show, Generic)
-- | Coarse shape of a 'LambdaForm', as computed by 'classify'.
data LambdaType
    = LambdaCon   -- ^ No arguments, and the body is a constructor application
    | LambdaFun   -- ^ At least one argument
    | LambdaThunk -- ^ No arguments, and any other kind of body
    deriving (Show, Ord, Eq)

-- | Short human-readable tag for each lambda type.
instance PrettyStgi LambdaType where
    prettyStgi LambdaCon   = "Con"
    prettyStgi LambdaFun   = "Fun"
    prettyStgi LambdaThunk = "Thunk"
-- | Classify a lambda form by its argument list and body:
--
-- * one or more bound (argument) variables: 'LambdaFun';
-- * no arguments and an 'AppC' body: 'LambdaCon';
-- * no arguments and any other body: 'LambdaThunk'.
--
-- The free variables and the update flag are not consulted.
classify :: LambdaForm -> LambdaType
classify (LambdaForm _ _ args body)
    | not (null args) = LambdaFun
    | AppC{} <- body  = LambdaCon
    | otherwise       = LambdaThunk
-- | Update flag of a 'LambdaForm'. It is rendered as @=>@ for 'Update' and
-- @->@ for 'NoUpdate'.
--
-- NOTE(review): 'Update' presumably means the closure is overwritten with
-- its value after evaluation (standard STG semantics). Confirm against the
-- evaluator.
data UpdateFlag =
      Update
    | NoUpdate
    deriving (Eq, Ord, Show, Generic, Enum, Bounded)

-- | Recursiveness of a @let@ binding group. Its rendering is appended
-- directly to the @let@ keyword: nothing for 'NonRecursive', @rec@ for
-- 'Recursive' (giving @letrec@).
data Rec =
      NonRecursive
    | Recursive
    deriving (Eq, Ord, Show, Generic, Enum, Bounded)
-- | An STG expression.
data Expr =
      Let !Rec !Binds !Expr
      -- ^ @let@ or @letrec@: the binding group, then the body printed after @in@
    | Case !Expr !Alts
      -- ^ @case@ scrutinee @of@ alternatives
    | AppF !Var ![Atom]
      -- ^ Function variable applied to atom arguments
    | AppC !Constr ![Atom]
      -- ^ Constructor applied to atom arguments
    | AppP !PrimOp !Atom !Atom
      -- ^ Binary primitive operation applied to two atoms
    | LitE !Literal
      -- ^ A literal used as an expression
    deriving (Eq, Ord, Show, Generic)
-- | The alternatives of a @case@: the non-default alternatives followed by
-- a mandatory default alternative.
data Alts = Alts !NonDefaultAlts !DefaultAlt
    deriving (Eq, Ord, Show, Generic)

-- | The non-default alternatives of a @case@. Algebraic and primitive
-- alternatives cannot be mixed, and 'NonEmpty' rules out an empty list in
-- the two non-trivial cases.
data NonDefaultAlts =
      NoNonDefaultAlts
      -- ^ Only the default alternative is present
    | AlgebraicAlts !(NonEmpty AlgebraicAlt)
      -- ^ Alternatives that match on constructors
    | PrimitiveAlts !(NonEmpty PrimitiveAlt)
      -- ^ Alternatives that match on literals
    deriving (Eq, Ord, Show, Generic)

-- | An alternative of the form @Con x y -> body@: the constructor, the
-- pattern variables, and the right-hand side.
data AlgebraicAlt = AlgebraicAlt !Constr ![Var] !Expr
    deriving (Eq, Ord, Show, Generic)

-- | An alternative of the form @1\# -> body@: the literal and the
-- right-hand side.
data PrimitiveAlt = PrimitiveAlt !Literal !Expr
    deriving (Eq, Ord, Show, Generic)

-- | The fallback alternative of a @case@.
data DefaultAlt =
      DefaultNotBound !Expr
      -- ^ Rendered as @default -> body@
    | DefaultBound !Var !Expr
      -- ^ Rendered as @v -> body@. NOTE(review): presumably binds the
      -- scrutinee's value to @v@; confirm against the evaluator.
    deriving (Eq, Ord, Show, Generic)
-- | An integer literal, rendered with a trailing hash (e.g. @1\#@).
newtype Literal = Literal Integer
    deriving (Eq, Ord, Show, Generic)

-- | Binary primitive operations, used by 'AppP'. Each one is rendered as
-- its Haskell-style operator followed by a hash (e.g. 'Add' as @+\#@).
-- Constructor order matters for the derived 'Ord', 'Enum' and 'Bounded'
-- instances.
data PrimOp =
      Add
    | Sub
    | Mul
    | Div
    | Mod
    | Eq
    | Lt
    | Leq
    | Gt
    | Geq
    | Neq
    deriving (Eq, Ord, Show, Generic, Bounded, Enum)
-- | A variable name.
newtype Var = Var Text
    deriving (Generic, Show, Ord, Eq)

-- | Lets variables be written as string literals.
instance IsString Var where
    fromString str = Var (T.pack str)

-- | An argument of an application: either a variable or a literal.
data Atom
    = AtomVar !Var
    | AtomLit !Literal
    deriving (Generic, Show, Ord, Eq)

-- | A data constructor name.
newtype Constr = Constr Text
    deriving (Generic, Show, Ord, Eq)

-- | Lets constructors be written as string literals.
instance IsString Constr where
    fromString str = Constr (T.pack str)
-- Template Haskell: derive 'Lift' instances for the AST types named here, so
-- that values of these types can be embedded into spliced code.
-- 'NonDefaultAlts', 'Binds', 'Constr' and 'Var' are deliberately absent; their
-- 'Lift' instances are written by hand in this module and go through
-- list/String representations (presumably because 'NonEmpty', 'Map' and
-- 'Text' had no usable 'Lift' instance when this was written; TODO confirm).
deriveLiftMany [ ''Program, ''Literal, ''LambdaForm, ''UpdateFlag, ''Rec
               , ''Expr, ''Alts, ''AlgebraicAlt, ''PrimitiveAlt, ''DefaultAlt
               , ''PrimOp, ''Atom ]
-- | Hand-written 'Lift' instance. 'NonDefaultAlts' is not part of the
-- 'deriveLiftMany' list.
--
-- Each non-empty list of alternatives is rebuilt in the generated code with
-- ':|' from its lifted head and lifted tail. The spliced expression is
-- therefore total by construction; it no longer round-trips through the
-- partial @NonEmpty.fromList@.
instance Lift NonDefaultAlts where
    lift = \case
        NoNonDefaultAlts ->
            [| NoNonDefaultAlts |]
        AlgebraicAlts (alt :| alts) ->
            [| AlgebraicAlts ($(lift alt) :| $(lift alts)) |]
        PrimitiveAlts (alt :| alts) ->
            [| PrimitiveAlts ($(lift alt) :| $(lift alts)) |]
-- | Lifts a binding group by embedding its ascending association list and
-- rebuilding the map with 'M.fromList' in the generated code.
instance Lift Binds where
    lift (Binds bindMap) = liftPairs (M.assocs bindMap)
      where
        liftPairs pairs = [| Binds (M.fromList $(lift pairs)) |]

-- | Lifts a constructor name through its 'String' form.
instance Lift Constr where
    lift (Constr name) = liftChars (T.unpack name)
      where
        liftChars chars = [| Constr (T.pack $(lift chars)) |]

-- | Lifts a variable name through its 'String' form.
instance Lift Var where
    lift (Var name) = liftChars (T.unpack name)
      where
        liftChars chars = [| Var (T.pack $(lift chars)) |]
-- | Stack documents vertically, aligned at the current column, with an
-- annotated semicolon after every element except the last.
--
-- Despite the name, the final element is /not/ followed by a semicolon:
-- 'punctuate' only inserts the separator between consecutive elements.
semicolonTerminated :: [Doc StgiAnn] -> Doc StgiAnn
semicolonTerminated docs = align (vsep (punctuate semicolon docs))
  where
    semicolon = annotate (AstAnn Semicolon) ";"
-- | A program is rendered exactly like its top-level bindings.
instance PrettyStgi Program where
    prettyStgi = \(Program topLevel) -> prettyStgi topLevel

-- | Each binding is rendered as @var = lambda@. Bindings appear in ascending
-- variable order and are passed to 'semicolonTerminated' for layout.
instance PrettyStgi Binds where
    prettyStgi (Binds bindMap) =
        semicolonTerminated
            [ prettyStgi var <+> "=" <+> prettyStgi lambda
            | (var, lambda) <- M.assocs bindMap ]
-- | Render a lambda form, using the supplied printer for its free variables.
--
-- The layout is:
--
-- * a backslash;
-- * the free variables in parentheses, only when there are any;
-- * the bound variables, attached directly to the backslash when there
--   are no free variables, otherwise separated by a space;
-- * @=>@ for an updatable form or @->@ for a non-updatable one;
-- * the body.
prettyLambda
    :: ([Var] -> Doc StgiAnn) -- ^ Printer for the free-variable list
    -> LambdaForm
    -> Doc StgiAnn
prettyLambda pprFree (LambdaForm free upd bound expr) =
    (header <+> arrow) <+> prettyStgi expr
  where
    backslash = "\\"
    withFree
        | null free = backslash
        | otherwise = backslash <> lparen <> pprFree free <> rparen
    boundDoc = hsep (map prettyStgi bound)
    header
        | null bound = withFree
        | null free  = withFree <> boundDoc
        | otherwise  = withFree <+> boundDoc
    arrow = case upd of
        Update   -> "=>"
        NoUpdate -> "->"
-- | Lambda forms render their free variables as a space-separated list.
instance PrettyStgi LambdaForm where
    prettyStgi lambda = prettyLambda spaced lambda
      where
        spaced vars = hsep [ prettyStgi v | v <- vars ]

-- | The suffix appended to the @let@ keyword: nothing for a plain @let@,
-- @rec@ to form @letrec@.
instance PrettyStgi Rec where
    prettyStgi NonRecursive = ""
    prettyStgi Recursive    = "rec"
-- | Render an expression.
--
-- * @let@/@letrec@ starts on a fresh line. The keyword and bindings, and
--   the @in@ body, are each indented by four columns.
-- * @case ... of@ is followed by its alternatives on the next lines,
--   indented by four and aligned.
-- * Applications render the head followed by space-separated arguments,
--   with no trailing space when there are no arguments.
instance PrettyStgi Expr where
    prettyStgi expr = case expr of
        Let recFlag binds body ->
            vsep [ line <> indent 4 (keyword ("let" <> prettyStgi recFlag) <+> prettyStgi binds)
                 , indent 4 (keyword "in" <+> prettyStgi body) ]
        Case scrutinee alts ->
            vsep [ hsep [keyword "case", prettyStgi scrutinee, keyword "of"]
                 , indent 4 (align (prettyStgi alts)) ]
        AppF fun args      -> applied (prettyStgi fun) args
        AppC con args      -> applied (prettyStgi con) args
        AppP op left right -> prettyStgi op <+> prettyStgi left <+> prettyStgi right
        LitE lit           -> prettyStgi lit
      where
        keyword = annotate (AstAnn Keyword)
        applied headDoc []   = headDoc
        applied headDoc args = headDoc <+> hsep (map prettyStgi args)
-- | The non-default alternatives are rendered first, in their given order,
-- followed by the default alternative. The list is laid out by
-- 'semicolonTerminated'. With no non-default alternatives, only the default
-- alternative is rendered.
instance PrettyStgi Alts where
    prettyStgi (Alts nonDefault defaultAlt) = case nonDefault of
        NoNonDefaultAlts -> defaultDoc
        AlgebraicAlts xs -> withDefault xs
        PrimitiveAlts xs -> withDefault xs
      where
        defaultDoc = prettyStgi defaultAlt
        withDefault :: PrettyStgi alt => NonEmpty alt -> Doc StgiAnn
        withDefault =
            semicolonTerminated . foldr (\a rest -> prettyStgi a : rest) [defaultDoc]
-- | @Con x y -> body@. When there are no pattern variables, the constructor
-- is followed directly by the arrow.
instance PrettyStgi AlgebraicAlt where
    prettyStgi (AlgebraicAlt con vars body) = pattern <+> "->" <+> prettyStgi body
      where
        pattern
            | null vars = prettyStgi con
            | otherwise = prettyStgi con <+> hsep (map prettyStgi vars)

-- | @lit -> body@.
instance PrettyStgi PrimitiveAlt where
    prettyStgi (PrimitiveAlt lit body) = hsep [prettyStgi lit, "->", prettyStgi body]

-- | @default -> body@ for an unbound default, or @v -> body@ for a bound one.
instance PrettyStgi DefaultAlt where
    prettyStgi alt = binder <+> "->" <+> prettyStgi body
      where
        (binder, body) = case alt of
            DefaultNotBound e -> ("default", e)
            DefaultBound v e  -> (prettyStgi v, e)
-- | The integer followed by a hash, annotated with @AstAnn Prim@.
instance PrettyStgi Literal where
    prettyStgi (Literal n) = annotate (AstAnn Prim) (pretty n <> "#")

-- | The operator symbol with a hash suffix, annotated with @AstAnn Prim@.
instance PrettyStgi PrimOp where
    prettyStgi = annotate (AstAnn Prim) . symbol
      where
        symbol Add = "+#"
        symbol Sub = "-#"
        symbol Mul = "*#"
        symbol Div = "/#"
        symbol Mod = "%#"
        symbol Eq  = "==#"
        symbol Lt  = "<#"
        symbol Leq = "<=#"
        symbol Gt  = ">#"
        symbol Geq = ">=#"
        symbol Neq = "/=#"

-- | The variable's name, annotated with @AstAnn Variable@.
instance PrettyStgi Var where
    prettyStgi var = case var of
        Var txt -> annotate (AstAnn Variable) (pretty txt)

-- | An atom renders as the variable or literal it wraps.
instance PrettyStgi Atom where
    prettyStgi (AtomVar v)   = prettyStgi v
    prettyStgi (AtomLit lit) = prettyStgi lit

-- | The constructor's name, annotated with @AstAnn Constructor@.
instance PrettyStgi Constr where
    prettyStgi con = case con of
        Constr txt -> annotate (AstAnn Constructor) (pretty txt)
-- Deep-evaluation support for the whole AST. Every instance body is empty,
-- so each one uses deepseq's default 'rnf'. That default is implemented via
-- the 'Generic' instance derived on each of these types.
instance NFData Program
instance NFData Binds
instance NFData LambdaForm
instance NFData UpdateFlag
instance NFData Rec
instance NFData Expr
instance NFData Alts
instance NFData NonDefaultAlts
instance NFData AlgebraicAlt
instance NFData PrimitiveAlt
instance NFData DefaultAlt
instance NFData Literal
instance NFData PrimOp
instance NFData Var
instance NFData Atom
instance NFData Constr