{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}

module TypedFsm.Core where

import Data.IFunctor (
  At (..),
  IFunctor (..),
  IMonad (..),
  returnAt,
  type (~>),
 )

import Data.Kind (Type)
import Data.Singletons (SingI)

-- | Class of state types that come with typed transition messages.
--
-- For a state type @ps@, an instance defines the data family 'Msg'.
-- A value of type @Msg ps from to@ is a message observed in state @from@
-- whose second index @to@ is the state the program continues in after
-- the message has been received.
class StateTransMsg ps where
  data Msg ps (from :: ps) (to :: ps)

{- | The core AST.

Essentially, all we do is build this AST and then interpret it.

@Operate m@ is an instance of 'IMonad' (for any 'Functor' @m@), and it can
embed actions of the internal monad @m@.

Typed-fsm has only two core functions: 'getInput' and 'liftm'.
We use these two functions to build 'Operate' programs.

The overall behavior is as follows: constantly read messages from the outside
and convert them into actions of the internal monad @m@.
-}
data Operate :: (Type -> Type) -> (ps -> Type) -> ps -> Type where
  -- | Finish in state @mode@, yielding a result of the state-indexed
  -- type @ia mode@.
  IReturn :: ia (mode :: ps) -> Operate m ia mode
  -- | Run an action of the internal monad @m@ that produces the rest of the
  -- program. The rest of the program may start in a different state
  -- @mode'@ than the current state @mode@. Both states need a 'SingI'
  -- instance.
  LiftM
    :: (SingI mode, SingI mode')
    => m (Operate m ia mode')
    -> Operate m ia mode
  -- | Wait for a message from the outside while in state @from@. The
  -- continuation receives a @Msg ps from to@ and the program continues in
  -- the message's second index @to@.
  In
    :: forall ps m (from :: ps) ia
     . (Msg ps from ~> Operate m ia)
    -> Operate m ia from

-- | Map over the final, state-indexed result of an 'Operate' program.
--
-- The natural transformation is pushed down to every 'IReturn' leaf.
-- The rest of the AST ('LiftM' actions and 'In' continuations) is left
-- structurally unchanged.
instance (Functor m) => IFunctor (Operate m) where
  imap f = \case
    IReturn ia -> IReturn (f ia)
    -- Map under the internal monad: the effect is preserved and only the
    -- program it produces is transformed. Re-applying LiftM is possible
    -- because matching on LiftM brought its SingI constraints into scope.
    LiftM act -> LiftM (fmap (imap f) act)
    -- Post-compose the continuation, so that whatever program follows each
    -- received message is mapped as well.
    In cont -> In (imap f . cont)
-- | Indexed-monad structure of 'Operate'.
--
-- 'ireturn' is 'IReturn'. @'ibind' f@ replaces every 'IReturn' leaf of a
-- program with @f@ applied to that leaf's result. Only @'Functor' m@ is
-- needed, because 'LiftM' actions are never run here, only mapped over.
instance (Functor m) => IMonad (Operate m) where
  ireturn = IReturn
  ibind f = \case
    IReturn ia -> f ia
    -- Bind inside the effect: the internal action runs first, then the
    -- program it yields continues with f.
    LiftM act -> LiftM (fmap (ibind f) act)
    -- Every program that can follow a received message is bound with f.
    In cont -> In (ibind f . cont)

-- | Get the next message from the outside.
--
-- The program starts in state @from@ and waits for a message. It then
-- immediately returns that message, ending in the message's target state.
getInput :: forall ps m (from :: ps). (Functor m) => Operate m (Msg ps from) from
getInput = In ireturn

-- | Lift an internal action @m a@ into @Operate m (At a mode) mode@.
--
-- The program stays in state @mode@. It runs the action and returns its
-- result wrapped in 'At' (via 'returnAt'), so the result is tied to the
-- state @mode@.
liftm :: forall ps m (mode :: ps) a. (Functor m, SingI mode) => m a -> Operate m (At a mode) mode
liftm act = LiftM (returnAt <$> act)