{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- | Symantic for 'Monad'.
module Language.Symantic.Lib.Monad where

import Control.Monad (Monad)
import Prelude hiding (Monad(..))
import qualified Control.Monad as Monad

import Language.Symantic
import Language.Symantic.Lib.Function (a0, b1, c2)
import Language.Symantic.Lib.Unit (tyUnit)
import Language.Symantic.Lib.Bool (tyBool)

-- * Class 'Sym_Monad'
type instance Sym Monad = Sym_Monad

-- | Symantic counterpart of the 'Monad' operations of "Control.Monad".
--
-- Every method also has a @default@ implementation: when @term@ has a
-- 'Trans' instance and @'UnT' term@ is itself a 'Sym_Monad', the operation
-- is forwarded to that underlying representation through 'trans1'
-- (unary methods) or 'trans2' (binary methods).
class Sym_Monad term where
	infixl 1 >>=
	infixr 1 >=>

	-- | Inject a value into the monad (cf. 'Monad.return').
	return :: Monad m => term a -> term (m a)
	default return
	 :: (Sym_Monad (UnT term), Trans term, Monad m)
	 => term a -> term (m a)
	return x = trans1 return x

	-- | Monadic bind (cf. 'Monad.>>=').
	(>>=) :: Monad m => term (m a) -> term (a -> m b) -> term (m b)
	default (>>=)
	 :: (Sym_Monad (UnT term), Trans term, Monad m)
	 => term (m a) -> term (a -> m b) -> term (m b)
	mx >>= k = trans2 (>>=) mx k

	-- | Collapse two monadic layers into one (cf. 'Monad.join').
	join :: Monad m => term (m (m a)) -> term (m a)
	default join
	 :: (Sym_Monad (UnT term), Trans term, Monad m)
	 => term (m (m a)) -> term (m a)
	join mmx = trans1 join mmx

	-- | Run the action only when the condition holds (cf. 'Monad.when').
	-- Only an 'Applicative' constraint is required.
	when :: Applicative f => term Bool -> term (f ()) -> term (f ())
	default when
	 :: (Sym_Monad (UnT term), Trans term, Applicative f)
	 => term Bool -> term (f ()) -> term (f ())
	when b fx = trans2 when b fx

	-- | Left-to-right Kleisli composition (cf. 'Monad.>=>').
	(>=>) :: Monad m => term (a -> m b) -> term (b -> m c) -> term (a -> m c)
	default (>=>)
	 :: (Sym_Monad (UnT term), Trans term, Monad m)
	 => term (a -> m b) -> term (b -> m c) -> term (a -> m c)
	f >=> g = trans2 (>=>) f g

-- Interpreting
-- | Evaluation: every method delegates to its "Control.Monad" namesake,
-- lifted over 'Eval' by 'eval1' (one argument) or 'eval2' (two arguments).
instance Sym_Monad Eval where
	return x  = eval1 Monad.return x
	mx >>= k  = eval2 (Monad.>>=) mx k
	join mmx  = eval1 Monad.join mmx
	when b fx = eval2 Monad.when b fx
	f >=> g   = eval2 (Monad.>=>) f g
-- | Rendering: prefix functions go through 'view1' / 'view2' with their
-- name, operators through 'viewInfix' with the same fixity the class
-- declares for them (@infixl 1 >>=@, @infixr 1 >=>@).
instance Sym_Monad View where
	return x  = view1 "return" x
	mx >>= k  = viewInfix ">>=" (infixL 1) mx k
	join mmx  = view1 "join" mmx
	when b fx = view2 "when" b fx
	f >=> g   = viewInfix ">=>" (infixR 1) f g
-- | Pairing of two interpreters: each method is dispatched through
-- 'dup1' / 'dup2' at class 'Sym_Monad'.
-- NOTE(review): presumably this runs the same operation in both @r1@ and
-- @r2@; confirm against the definition of 'Dup'.
instance (Sym_Monad r1, Sym_Monad r2) => Sym_Monad (Dup r1 r2) where
	return x  = dup1 @Sym_Monad return x
	mx >>= k  = dup2 @Sym_Monad (>>=) mx k
	join mmx  = dup1 @Sym_Monad join mmx
	when b fx = dup2 @Sym_Monad when b fx
	f >=> g   = dup2 @Sym_Monad (>=>) f g

-- Transforming
-- | 'BetaT' instance with no body: every method comes from the class's
-- @default@ implementations, i.e. it is forwarded with 'trans1' / 'trans2'
-- to @'UnT' ('BetaT' term)@.
-- NOTE(review): presumably @'UnT' ('BetaT' term) ~ term@ and the
-- 'Sym_Lambda' constraint is what the 'Trans' instance of 'BetaT' needs;
-- confirm against their definitions.
instance (Sym_Monad term, Sym_Lambda term) => Sym_Monad (BetaT term)

-- Typing
-- | The type class 'Monad' is named @Monad@, qualified by the path
-- @["Monad"]@ (the same path 'moduleFor' uses below).
instance NameTyOf Monad where
	nameTyOf _ = Mod ["Monad"] "Monad"
-- The remaining typing instances have no body and rely on the class defaults.
instance FixityOf Monad
instance ClassInstancesFor Monad
instance TypeInstancesFor Monad

-- Compiling
-- | No extra grammar atoms: the class defaults are used.
instance Gram_Term_AtomsFor src ss g Monad
-- | Module @Monad@ exporting the five terms defined below; the operator
-- fixities repeat the ones declared in 'Sym_Monad'.
instance (Source src, SymInj ss Monad) => ModuleFor src ss Monad where
	moduleFor = moduleWhere ["Monad"]
	 [ "return" := teMonad_return
	 , "join"   := teMonad_join
	 , "when"   := teMonad_when
	 , ">>=" `withInfixL` 1 := teMonad_bind
	 , ">=>" `withInfixR` 1 := teMonad_kleisli_l2r
	 ]

-- ** 'Type's
-- | The constraint type @'Monad' m@, built by applying the 'Monad'
-- constant (sized to the variables of @ty@) to the given type @ty@.
tyMonad :: Source src => Type src vs m -> Type src vs (Monad m)
tyMonad ty = tyApp (tyConstLen @(K Monad) @Monad (lenVars ty)) ty

-- | Type variable named @m@, at position 0 of the variable list.
m0 :: (Source src, LenInj vs, KindInj (K m))
   => Type src (Proxy m ': vs) m
m0 = tyVar "m" varZ

-- | Type variable named @m@, at position 1 of the variable list.
m1 :: (Source src, LenInj vs, KindInj (K m))
   => Type src (a ': Proxy m ': vs) m
m1 = tyVar "m" (VarS varZ)

-- | Type variable named @m@, at position 2 of the variable list.
m2 :: (Source src, LenInj vs, KindInj (K m))
   => Type src (a ': b ': Proxy m ': vs) m
m2 = tyVar "m" (VarS (VarS varZ))

-- | Type variable named @m@, at position 3 of the variable list.
m3 :: (Source src, LenInj vs, KindInj (K m))
   => Type src (a ': b ': c ': Proxy m ': vs) m
m3 = tyVar "m" (VarS (VarS (VarS varZ)))

-- ** 'Term's
-- | Term for @return :: Monad m => a -> m a@.
teMonad_return :: TermDef Monad '[Proxy a, Proxy m] (Monad m #> (a -> m a))
teMonad_return =
  Term (tyMonad m1)
    (a0 ~> tyApp m1 a0)
    (teSym @Monad (lam1 return))

-- | Term for @(>>=) :: Monad m => m a -> (a -> m b) -> m b@.
teMonad_bind :: TermDef Monad '[Proxy a, Proxy b, Proxy m] (Monad m #> (m a -> (a -> m b) -> m b))
teMonad_bind =
  Term (tyMonad m2)
    (tyApp m2 a0 ~> ((a0 ~> tyApp m2 b1) ~> tyApp m2 b1))
    (teSym @Monad (lam2 (>>=)))

-- | Term for @join :: Monad m => m (m a) -> m a@.
teMonad_join :: TermDef Monad '[Proxy a, Proxy m] (Monad m #> (m (m a) -> m a))
teMonad_join =
  Term (tyMonad m1)
    (tyApp m1 (tyApp m1 a0) ~> tyApp m1 a0)
    (teSym @Monad (lam1 join))

-- | Term for left-to-right Kleisli composition,
-- @(>=>) :: Monad m => (a -> m b) -> (b -> m c) -> (a -> m c)@.
teMonad_kleisli_l2r :: TermDef Monad '[Proxy a, Proxy b, Proxy c, Proxy m] (Monad m #> ((a -> m b) -> (b -> m c) -> (a -> m c)))
teMonad_kleisli_l2r =
  Term (tyMonad m3)
    ((a0 ~> tyApp m3 b1) ~> ((b1 ~> tyApp m3 c2) ~> (a0 ~> tyApp m3 c2)))
    (teSym @Monad (lam2 (>=>)))

-- | Term for @when :: Monad m => Bool -> m () -> m ()@.
-- NOTE(review): the 'Sym_Monad' method 'when' only needs 'Applicative',
-- but this term is typed with a 'Monad' constraint; confirm this is intended.
teMonad_when :: TermDef Monad '[Proxy m] (Monad m #> (Bool -> m () -> m ()))
teMonad_when =
  Term (tyMonad m0)
    (tyBool ~> (tyApp m0 tyUnit ~> tyApp m0 tyUnit))
    (teSym @Monad (lam2 when))