{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# OPTIONS_GHC -Wall #-}

-- | Continuation-passing wrappers over a base type constructor @m@:
-- 'Cont', whose answer type is universally quantified, and 'Cont_', whose
-- answer type is fixed to @()@.
module Box.Cont
  ( Cont(..)
  , Cont_(..)
  ) where

import Control.Applicative
import Control.Monad.IO.Class (MonadIO(liftIO))
import Data.Monoid (Monoid(..))
import Data.Semigroup (Semigroup(..))

-- | A continuation-passing computation, similar to @ContT@, except that the
-- answer type @r@ is not a parameter of 'Cont'. It is universally quantified
-- inside the field (a rank-2 type, not an existential), so a single 'Cont'
-- value can be run with a continuation returning any @m r@.
newtype Cont m a = Cont
  { with :: forall r. (a -> m r) -> m r
    -- ^ Run the computation, handing its result to the supplied continuation.
  }

-- | Mapping applies the function to the value just before it is handed to the
-- continuation; no constraint on @m@ is required.
instance Functor (Cont m) where
  fmap g (Cont run) = Cont (\k -> run (k . g))

-- | 'pure' passes the value straight to the continuation. '<*>' runs the
-- function computation first, then the argument computation, and feeds the
-- applied result to the continuation.
instance Applicative (Cont m) where
  pure x = Cont (\k -> k x)
  Cont runF <*> Cont runX = Cont (\k -> runF (\g -> runX (k . g)))

-- | Sequencing runs the first computation, feeds its result to @f@, and runs
-- the resulting computation with the outer continuation.
instance Monad (Cont m) where
  -- Canonical definition: reuse 'pure' rather than duplicating its body, which
  -- is the form GHC's -Wnoncanonical-monad-instances check expects.
  return = pure
  -- The outer continuation is passed on directly; no wrapping lambda needed.
  ma >>= f = Cont (\k -> ma `with` \a -> f a `with` k)

-- | Lifts an 'IO' action by running it in the base monad via its own 'liftIO'
-- and passing the result on to the continuation.
instance (MonadIO m) => MonadIO (Cont m) where
  liftIO io = Cont (\k -> liftIO io >>= k)

-- | Combines two computations applicatively: the left one runs first, then the
-- right one, and their results are joined with the underlying '<>'.
instance (Semigroup a) => Semigroup (Cont m a) where
  lhs <> rhs = (<>) <$> lhs <*> rhs

-- | The identity is a computation that passes 'mempty' to its continuation.
--
-- No constraint on @m@ is required: both 'pure' and '<>' for 'Cont' are
-- defined without using any structure of @m@. The explicit @Semigroup a@
-- constraint is kept so the instance also compiles against versions of @base@
-- in which 'Semigroup' is not yet a superclass of 'Monoid'.
instance (Semigroup a, Monoid a) => Monoid (Cont m a) where
  mempty = pure mempty
  mappend = (<>)

-- | A variant of 'Cont' with the answer type fixed to @()@: the continuation
-- must return @m ()@, and so does the computation as a whole.
--
-- Presumably intended for situations where there is no choice but to discard
-- the result (to "void it up"), such as callbacks that may only return unit
-- -- TODO confirm against callers.
newtype Cont_ m a = Cont_
  { with_ :: (a -> m ()) -> m ()
    -- ^ Run the computation, handing its result to the supplied continuation.
  }

-- | Mapping applies the function to the value just before it is handed to the
-- continuation; no constraint on @m@ is required.
instance Functor (Cont_ m) where
  fmap g (Cont_ run) = Cont_ (\k -> run (k . g))

-- | 'pure' passes the value straight to the continuation. '<*>' runs the
-- function computation first, then the argument computation, and feeds the
-- applied result to the continuation.
instance Applicative (Cont_ m) where
  pure x = Cont_ (\k -> k x)
  Cont_ runF <*> Cont_ runX = Cont_ (\k -> runF (\g -> runX (k . g)))

-- | Sequencing runs the first computation, feeds its result to @f@, and runs
-- the resulting computation with the outer continuation.
instance Monad (Cont_ m) where
  -- Canonical definition: reuse 'pure' rather than duplicating its body, which
  -- is the form GHC's -Wnoncanonical-monad-instances check expects.
  return = pure
  -- The outer continuation is passed on directly; no wrapping lambda needed.
  ma >>= f = Cont_ (\k -> ma `with_` \a -> f a `with_` k)

-- | Lifts an 'IO' action by running it in the base monad via its own 'liftIO'
-- and passing the result on to the continuation.
instance (MonadIO m) => MonadIO (Cont_ m) where
  liftIO io = Cont_ (\k -> liftIO io >>= k)

-- | Combines two computations applicatively: the left one runs first, then the
-- right one, and their results are joined with the underlying '<>'.
instance (Semigroup a) => Semigroup (Cont_ m a) where
  lhs <> rhs = (<>) <$> lhs <*> rhs

-- | The identity is a computation that passes 'mempty' to its continuation.
--
-- No constraint on @m@ is required: both 'pure' and '<>' for 'Cont_' are
-- defined without using any structure of @m@. The explicit @Semigroup a@
-- constraint is kept so the instance also compiles against versions of @base@
-- in which 'Semigroup' is not yet a superclass of 'Monoid'.
instance (Semigroup a, Monoid a) => Monoid (Cont_ m a) where
  mempty = pure mempty
  mappend = (<>)