{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module TypedSession.State.Utils where

import Control.Carrier.Fresh.Strict
import Control.Carrier.State.Strict
import Control.Effect.Writer
import Control.Monad
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import qualified Data.List as L
import Data.Maybe (fromJust, fromMaybe)
import Data.Sequence (Seq)
import qualified Data.Sequence as Seq
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Traversable (for)
import qualified TypedSession.State.Constraint as C
import TypedSession.State.Type
import Prelude hiding (traverse)

------------------------

-- | Run an action, then put back the state of type @s@ that was current
-- before the action started. Any state changes the action made are
-- discarded, but its result is returned.
--
-- @s@ appears only in the constraint, so it must be fixed with a type
-- application at the call site (e.g. @restoreWrapper \@(Set Int) action@);
-- this relies on @AllowAmbiguousTypes@.
--
-- NOTE(review): the restoring 'put' is sequenced after the action in the same
-- monad. If the action short-circuits (e.g. through an error effect), the
-- saved state is presumably not put back. Confirm against the carrier order
-- in use.
restoreWrapper
  :: forall s sig m a
   . (Has (State s) sig m) => m a -> m a
restoreWrapper action =
  get @s >>= \saved -> action <* put saved

-- | The @(from, to)@ roles of the first 'Msg' in a protocol, if any.
--
-- Non-'Msg' steps at the head of a @:>@ sequence are skipped. The search
-- stops with 'Nothing' at the first node that is not a @:>@ step ('Branch',
-- 'Goto' or 'Terminal'), so messages inside branches are never considered.
getFirstMsgInfo :: Protocol eta r bst -> Maybe (r, r)
getFirstMsgInfo (Msg _ _ _ from to :> _) = Just (from, to)
getFirstMsgInfo (_ :> rest) = getFirstMsgInfo rest
getFirstMsgInfo _ = Nothing

-- | The @(from, to)@ roles of every 'Msg' in a protocol, in order of
-- appearance.
--
-- Non-'Msg' steps of a @:>@ sequence contribute nothing. For a 'Branch', the
-- pairs of each alternative are concatenated in list order. 'Goto' and
-- 'Terminal' contribute no pairs and end that path of the traversal.
getAllMsgInfo :: Protocol eta r bst -> [(r, r)]
getAllMsgInfo protocol = case protocol of
  Msg _ _ _ from to :> rest -> (from, to) : getAllMsgInfo rest
  _ :> rest -> getAllMsgInfo rest
  Branch _ _ _ alternatives ->
    foldMap (\(BranchSt _ _ _ body) -> getAllMsgInfo body) alternatives
  Goto _ _ -> []
  Terminal _ -> []

-- | Write a list of values to a 'Seq'-valued 'Writer', keeping their order.
tellSeq :: (Has (Writer (Seq a)) sig m) => [a] -> m ()
tellSeq = tell . Seq.fromList

-- | Apply a substitution map to a list of integers, recording every output.
--
-- Each element @k@ becomes its image in @sbm@, or stays @k@ when @sbm@ has no
-- entry for it. Every resulting value (substituted or not) is inserted into
-- the @Set Int@ state. The rewritten list keeps the input's order and length.
replaceList :: (Has (State (Set Int)) sig m) => C.SubMap -> [Int] -> m [Int]
replaceList sbm ls =
  forM ls $ \k -> do
    let substituted = IntMap.findWithDefault k k sbm
    substituted <$ modify (Set.insert substituted)

-- | Look up a key that the substitution map is required to contain.
--
-- Unlike 'replaceList', a missing key is not passed through unchanged; it
-- raises 'error' with the 'internalError' message. The default is only
-- forced when the key is absent.
replaceVal :: IntMap Int -> Int -> Int
replaceVal sbm k = IntMap.findWithDefault (error internalError) k sbm

-- | Every value of a bounded enumeration, from 'minBound' up to 'maxBound'
-- in 'Enum' order, e.g. @rRange \@Bool == [False, True]@.
rRange :: forall r. (Enum r, Bounded r) => [r]
rRange = enumFromTo (minBound @r) (maxBound @r)