-- SPDX-FileCopyrightText: 2022 Oxhead Alpha
-- SPDX-License-Identifier: LicenseRef-MIT-OA

-- | Poor man's extensible reader effects via nested 'ReaderT'.
module Morley.Util.MultiReader
  ( MultiReaderT
  , MonadMultiReaderT
  , ChangeMultiReaderBase
  , asks'
  , ask'
  , local'
  , mapMultiReaderT
  ) where

import Control.Monad.Reader (mapReaderT)
import GHC.TypeLits (ErrorMessage(..), TypeError)

import Morley.Util.Peano

-- | Stack one 'ReaderT' layer per element of the environment list on top of
-- a base monad; the head of the list becomes the outermost layer.
--
-- For instance, @MultiReaderT '[A, B] m@ reduces to @ReaderT A (ReaderT B m)@.
type family MultiReaderT (envs :: [Type]) (base :: Type -> Type) :: Type -> Type where
  -- The two equations are apart ('[] vs ':), so their order is irrelevant.
  MultiReaderT '[] base = base
  MultiReaderT (env ': rest) base = ReaderT env (MultiReaderT rest base)

-- | Convenience constraint synonym saying two things at once:
--
-- * @m@ is the nested 'ReaderT' stack over @base@ whose environments are
--   the ones listed by @'MultiReaderIso' m@;
-- * 'mapMultiReaderT' is available for that list of environments.
--
-- Intended for use together with `asks'`, `ask'`, `local'` and
-- 'mapMultiReaderT'.
type MonadMultiReaderT m base =
  ( m ~ MultiReaderT (MultiReaderIso m) base
  , MonadMultiReaderMap (MultiReaderIso m)
  )

-- | Swap the base monad under a nested 'ReaderT' stack: the environments that
-- 'MultiReaderIso' finds in @stack@ are re-stacked, in the same order, over
-- @newBase@.
type ChangeMultiReaderBase stack newBase =
  MultiReaderT (MultiReaderIso stack) newBase

-- | Zero-based position of the first occurrence of @r@ in the list @rs@,
-- expressed as a 'Peano' number.
--
-- Essentially a type-level version of @elemIndex@ from "Data.List", except
-- that a missing element is reported at compile time instead of yielding
-- 'Nothing'.
--
-- Raises a custom type error if @r@ does not occur anywhere in @rs@.
type family MultiReaderDepth (r :: Type) (rs :: [Type]) :: Peano where
  -- Order matters: the non-linear first equation (head equal to @r@) must be
  -- tried before the second one, which skips the head unconditionally.
  MultiReaderDepth r (r ': _) = 'Z
  MultiReaderDepth r (_ ': rs) = 'S (MultiReaderDepth r rs)
  MultiReaderDepth r '[] = TypeError (
    'Text "MultiReaderT does not have a reader environment" ':$$:
    'ShowType r ':$$: 'Text "anywhere in the stack."
    )

-- | Read back the list of reader environments from a stack of nested
-- 'ReaderT', outermost layer first. Any monad that is not a 'ReaderT'
-- application contributes no environments (@'[]@).
--
-- Together with 'MultiReaderT' this witnesses the isomorphism between a
-- 'ReaderT' stack and a type-level list of environments. It is needed because
-- 'MultiReaderT' can't have an injectivity annotation.
type family MultiReaderIso (m :: Type -> Type) :: [Type] where
  -- Order matters: the catch-all below must only fire when the monad is not
  -- a 'ReaderT'.
  MultiReaderIso (ReaderT env inner) = env ': MultiReaderIso inner
  MultiReaderIso _ = '[]

-- | Variants of 'ask' and 'local' for a stack of nested 'ReaderT' that are not
-- tied to a single environment by a functional dependency (as @MonadReader@
-- is). This lets one stack serve several environments.
--
-- The index @n@ is not free: the superclass equation fixes it to the position
-- of the first layer whose environment is @r@.
class (n ~ MultiReaderDepth r (MultiReaderIso m), Monad m)
  => MultiReader (n :: Peano) r m where
  -- | Like 'local', but acting on the layer that holds @r@. The computation is
  -- mapped through as many outer layers as needed to reach it.
  local' :: (r -> r) -> m a -> m a

  -- | Like 'ask', but reading the layer that holds @r@. The result is lifted
  -- through as many outer layers as needed to reach it.
  ask' :: m r

-- | Base case: the outermost 'ReaderT' layer already holds the requested
-- environment @x@, so the ordinary 'ask' and 'local' are used directly.
instance (Monad m) => MultiReader 'Z x (ReaderT x m) where
  ask' = ask
  local' = local

-- | Inductive case: the requested environment @r@ is not in the outer
-- @'ReaderT' x@ layer but sits at depth @n@ inside @m@. `ask'` therefore lifts
-- the inner `ask'` through the outer layer. `local'` applies the inner
-- `local'` underneath the outer layer with 'mapReaderT'.
instance
  ( MultiReader n r m, Monad m, 'S n ~ MultiReaderDepth r (x ': MultiReaderIso m) )
  => MultiReader ('S n) r (ReaderT x m) where
  ask' = lift ask'
  local' = mapReaderT . local'

-- | Lifts 'mapReaderT' from a single 'ReaderT' layer to a whole nested stack,
-- indexed by the list of environments @xs@.
class MonadMultiReaderMap xs where
  -- | Transform the base-monad computation underneath a nested 'ReaderT'
  -- stack while leaving every environment layer in place. This is
  -- 'mapReaderT' pushed through all levels rather than just the outermost one.
  mapMultiReaderT
    :: ( stackM ~ MultiReaderT xs m
       , stackN ~ MultiReaderT xs n
       , xs ~ MultiReaderIso stackM
       )
    => (m a -> n b) -> stackM a -> stackN b

-- | Empty stack: there are no 'ReaderT' layers, so the base computation is the
-- whole computation and the function is applied as is.
instance MonadMultiReaderMap '[] where
  mapMultiReaderT f = f

-- | Non-empty stack: map under the outermost 'ReaderT' with 'mapReaderT',
-- recursing into the remaining layers until the base computation is reached.
instance (MonadMultiReaderMap xs) => MonadMultiReaderMap (x ': xs) where
  mapMultiReaderT f = mapReaderT (mapMultiReaderT f)

-- | Unconstrained version of 'asks': read the environment @r@ from whichever
-- layer of the stack holds it and project it through the given function.
--
-- @asks' f = fmap f ask'@.
--
-- The explicit @forall@ fixes the order @m r a n@ for type applications.
asks' :: forall m r (a :: Type) n. MultiReader n r m => (r -> a) -> m a
asks' f = f <$> ask'