{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      : Jikka.Common.Alpha
-- Description : provides a monad to run alpha-conversion. / alpha 変換用のモナドを提供します。
-- Copyright   : (c) Kimiyuki Onaka, 2020
-- License     : Apache License 2.0
-- Maintainer  : kimiyuki95@gmail.com
-- Stability   : experimental
-- Portability : portable
--
-- "Jikka.Common.Alpha" provides a monad for alpha-conversion. Its only feature is generating unique numbers.
module Jikka.Common.Alpha where

import Control.Arrow (first)
import Control.Monad.Except
import Control.Monad.Identity (Identity (..))
import Control.Monad.Reader
import Control.Monad.Signatures
import Control.Monad.State.Strict
import Control.Monad.Writer.Strict
import Data.Unique
import Language.Haskell.TH (Q)

-- | Monads that can hand out counter values, used to make unique numbers
-- (e.g. fresh names during alpha-conversion).
class (Monad m) => MonadAlpha m where
  -- | Return the next counter value. Instances are intended to return
  -- distinct values on successive calls.
  nextCounter :: m Int

-- | A monad transformer that threads an 'Int' counter through a computation.
--
-- @runAlphaT act i@ runs @act@ starting from counter value @i@ and returns
-- the result paired with the final counter value.
newtype AlphaT m a = AlphaT {runAlphaT :: Int -> m (a, Int)}

-- | Each call returns the current counter and advances it by one.
instance Monad m => MonadAlpha (AlphaT m) where
  nextCounter = AlphaT $ \i ->
    -- Force the incremented counter so that long runs of 'nextCounter'
    -- do not accumulate a chain of unevaluated (+ 1) thunks.
    let i' = i + 1
     in i' `seq` return (i, i')

-- | Maps over the result and leaves the counter untouched.
instance Functor m => Functor (AlphaT m) where
  -- 'first' on functions matches the pair lazily, so the counter component
  -- is not forced here.
  fmap f (AlphaT x) = AlphaT (fmap (first f) . x)

-- | @pure@ leaves the counter unchanged. @<*>@ runs the function computation
-- first, then the argument computation, threading the counter through both.
instance Monad m => Applicative (AlphaT m) where
  pure x = AlphaT (\i -> pure (x, i))
  AlphaT mf <*> AlphaT mx = AlphaT $ \i -> do
    (f, i') <- mf i
    (x, i'') <- mx i'
    return (f x, i'')

-- | Sequencing passes the counter produced by the first computation on to
-- the continuation.
instance Monad m => Monad (AlphaT m) where
  AlphaT mx >>= k = AlphaT $ \i -> do
    (x, i') <- mx i
    runAlphaT (k x) i'

-- | The fixed point is taken over the result component only. The recursive
-- computation is run from the incoming counter.
instance MonadFix m => MonadFix (AlphaT m) where
  -- The lazy pattern is essential: matching the pair strictly would force
  -- the fixed point before it exists.
  mfix f = AlphaT $ \i -> mfix $ \ ~(a, _) -> runAlphaT (f a) i

-- | Lift a catch-like operation of the underlying monad through 'AlphaT'.
--
-- Both the protected action and the handler start from the counter value in
-- effect when the catch began. Counter increments made by the action before
-- it failed are therefore discarded.
--
-- NOTE(review): if a counter value obtained inside the action escapes through
-- the error value, the handler may hand out the same number again.
liftCatch :: Catch e m (a, Int) -> Catch e (AlphaT m) a
liftCatch catchE m h = AlphaT $ \i ->
  runAlphaT m i `catchE` \e -> runAlphaT (h e) i

-- | Runs the lifted action and passes the counter through unchanged.
instance MonadTrans AlphaT where
  lift m = AlphaT $ \i -> (\a -> (a, i)) <$> m

-- | Errors are thrown in the underlying monad. 'catchError' is lifted with
-- 'liftCatch', so the handler restarts from the counter value that was in
-- effect when the catch began.
instance MonadError e m => MonadError e (AlphaT m) where
  throwError = lift . throwError
  catchError = liftCatch catchError

-- | Lifts 'IO' actions through the underlying monad without touching the
-- counter.
instance MonadIO m => MonadIO (AlphaT m) where
  liftIO = lift . liftIO

-- | Run an 'AlphaT' computation from the given initial counter value and
-- return only its result, discarding the final counter.
evalAlphaT :: Functor m => AlphaT m a -> Int -> m a
evalAlphaT f i = fst <$> runAlphaT f i

-- | Delegates to the underlying monad's counter.
instance MonadAlpha m => MonadAlpha (ExceptT e m) where
  nextCounter = lift nextCounter

-- | Delegates to the underlying monad's counter.
instance MonadAlpha m => MonadAlpha (ReaderT r m) where
  nextCounter = lift nextCounter

-- | Delegates to the underlying monad's counter.
instance MonadAlpha m => MonadAlpha (StateT s m) where
  nextCounter = lift nextCounter

-- | Delegates to the underlying monad's counter.
instance (MonadAlpha m, Monoid w) => MonadAlpha (WriterT w m) where
  nextCounter = lift nextCounter

-- | Run a pure 'AlphaT' computation from the given initial counter value and
-- return its result.
evalAlpha :: AlphaT Identity a -> Int -> a
evalAlpha f i = runIdentity (evalAlphaT f i)

-- | Overwrite the counter with the given value, ignoring its current value.
--
-- Later 'nextCounter' calls continue from the new value, so they may return
-- numbers that were already handed out.
resetAlphaT :: Monad m => Int -> AlphaT m ()
resetAlphaT i = AlphaT (const (return ((), i)))

-- | Draws a number from 'newUnique'. These values are unrelated to any
-- 'AlphaT' counter.
--
-- NOTE(review): 'hashUnique' is documented as possibly mapping distinct
-- 'Unique's to the same 'Int', so uniqueness here is best-effort.
instance MonadAlpha IO where
  nextCounter = hashUnique <$> newUnique

-- | Delegates to the 'IO' instance via 'liftIO'.
instance MonadAlpha Q where
  nextCounter = liftIO (nextCounter :: IO Int)