{-# OPTIONS_GHC -Wall #-}

{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}

module Data.Id where

import Control.Applicative ( Alternative(empty, (<|>)) )
import Control.Lens
    ( Wrapped(..),
      Rewrapped,
      Iso,
      Lens',
      Prism,
      Getter,
      iso,
      over,
      _Wrapped,
      view,
      _Just,
      _Unwrapped,
      re,
      _Nothing )
import Control.Monad.Cont ( MonadCont(..) )
import Control.Monad.Error.Class ( MonadError(..) )
import Control.Monad.IO.Class ( MonadIO(..) )
import Control.Monad.Reader.Class ( MonadReader(..) )
import Control.Monad.RWS.Class ( MonadRWS )
import Control.Monad.State.Class ( MonadState(..) )
import Control.Monad.Trans.Class ( MonadTrans(..) )
import Control.Monad.Writer( MonadWriter(..) )
import Data.Functor.Alt ( Alt((<!>)) )
import Data.Functor.Apply ( Apply((<.>)) )
import Data.Functor.Bind ( Bind((>>-)) )
import Data.Functor.Bind.Trans ( BindTrans(..) )
import Data.Functor.Classes
    ( Eq1(..), Ord1(..), Show1(liftShowsPrec) )
import Data.Functor.Identity ( Identity(Identity) )
import Data.Semigroup.Foldable ( Foldable1(foldMap1) )
import Data.Semigroup.Traversable ( Traversable1(traverse1) )
import Data.Tagged ( Tagged(Tagged), untag )
import Data.Proxy ( Proxy )

-- | A transparent wrapper around a value of type @f a@.
--
-- 'Eq' and 'Ord' are derived, so they compare exactly like the wrapped
-- @f a@; the derived 'Show' prints the constructor name @Id@ followed by the
-- payload at application precedence.
newtype Id f a = Id (f a) deriving (Eq, Ord, Show)

-- | Lifted equality: two 'Id's are equal under an element equality exactly
-- when the wrapped @f@ structures are.
instance Eq1 f => Eq1 (Id f) where
  liftEq eq (Id lhs) (Id rhs) = liftEq eq lhs rhs

-- | Lifted ordering: delegates directly to the wrapped @f@'s 'liftCompare'.
instance Ord1 f => Ord1 (Id f) where
  liftCompare cmp (Id lhs) (Id rhs) = liftCompare cmp lhs rhs

-- | Lifted 'Show' for 'Id', rendering values exactly as the derived 'Show'
-- instance does: the constructor name @Id@ followed by the payload shown at
-- application precedence (11). The whole expression is parenthesised only
-- when the surrounding precedence @d@ exceeds 10.
--
-- The payload must be shown at precedence 11, not the outer @d@. Otherwise it
-- is either left unparenthesised (@Id Just 1@) or parenthesised twice.
instance Show1 f => Show1 (Id f) where
  liftShowsPrec sp sl d (Id x) =
    showParen (d > 10) (showString "Id " . liftShowsPrec sp sl 11 x)

-- | Lets lens's '_Wrapped' / '_Unwrapped' change the element type of an
-- 'Id' (@a@ to @b@) while keeping the wrapped functor @f@ fixed. Use
-- '__Wrapped' when @f@ must change as well.
instance (t ~ Id f b) => Rewrapped (Id f a) t

-- | lens 'Wrapped' instance: '_Wrapped'' is the iso between an 'Id' and
-- the @f a@ it holds.
instance Wrapped (Id f a) where
  type Unwrapped (Id f a) = f a
  _Wrapped' = iso runId Id
    where
      runId (Id payload) = payload

-- | A fully type-changing unwrapping iso for 'Id'. Unlike the 'Rewrapped'
-- instance, which keeps @f@ fixed, it lets both the functor (@f@ to @f'@)
-- and the element type (@a@ to @a'@) change.
__Wrapped :: Iso (Id f a) (Id f' a') (f a) (f' a')
__Wrapped =
  iso unwrapId Id
  where
    unwrapId (Id payload) = payload

-- | Types @a@ that contain an @'Id' f a'@ reachable through a lens.
-- The functional dependency @a -> f a'@ means the containing type alone
-- determines which 'Id' it holds.
class HasId a f a' | a -> f a' where
  id' :: Lens' a (Id f a')

-- | An 'Id' trivially contains itself; the lens is the identity optic.
instance HasId (Id f a') f a' where
  id' focus = focus

-- | Types @a@ from which an @'Id' f a'@ can be accessed via '_Id'. As with
-- 'HasId', @a@ alone determines @f@ and @a'@.
--
-- NOTE(review): lens's @As*@ / @_Ctor@ naming convention normally denotes a
-- 'Prism'', but '_Id' is declared as a 'Lens''. It is left as a lens
-- because existing callers may rely on it being usable as one.
class AsId a f a' | a -> f a' where
  _Id :: Lens' a (Id f a')

-- | An 'Id' is its own '_Id'; the optic is the identity.
instance AsId (Id f a') f a' where
  _Id = \k -> k

-- | Combines two 'Id's by combining their payloads with @f a@'s '<>'.
instance Semigroup (f a) => Semigroup (Id f a) where
  (<>) (Id lhs) (Id rhs) = Id (lhs <> rhs)

-- | The identity element wraps @f a@'s own 'mempty'.
instance Monoid (f a) => Monoid (Id f a) where
  mempty = Id mempty

-- | Maps over the payload using the wrapped functor's 'fmap'.
instance Functor f => Functor (Id f) where
  fmap g (Id fa) = Id (fmap g fa)

-- | Semigroupoids 'Apply', delegated to the wrapped @f@.
instance Apply f => Apply (Id f) where
  (<.>) (Id ff) (Id fx) = Id (ff <.> fx)

-- | Applicative structure of the wrapped @f@.
--
-- Besides 'pure' and '<*>', the sequencing operators '*>' and '<*' are
-- forwarded to @f@. The class defaults would rebuild them from 'fmap'
-- and '<*>', discarding any specialised implementation @f@ provides.
instance Applicative f => Applicative (Id f) where
  pure =
    Id . pure
  Id x <*> Id y =
    Id (x <*> y)
  Id x *> Id y =
    Id (x *> y)
  Id x <* Id y =
    Id (x <* y)

-- | Semigroupoids 'Alt', delegated to the wrapped @f@.
instance Alt f => Alt (Id f) where
  (<!>) (Id lhs) (Id rhs) = Id (lhs <!> rhs)

-- | 'Alternative' delegated to the wrapped @f@ for 'empty' and '<|>';
-- the remaining methods use the class defaults.
instance Alternative f => Alternative (Id f) where
  empty = Id empty
  (<|>) (Id lhs) (Id rhs) = Id (lhs <|> rhs)

-- | Semigroupoids 'Bind': binds in @f@, unwrapping each continuation's
-- 'Id' result.
instance Bind f => Bind (Id f) where
  Id m >>- k = Id (m >>- unwrapped . k)
    where
      unwrapped (Id n) = n

-- | 'Monad' structure of the wrapped @f@: binds in @f@, unwrapping each
-- continuation's 'Id' result.
instance Monad f => Monad (Id f) where
  Id m >>= k = Id (m >>= unwrapped . k)
    where
      unwrapped (Id n) = n

-- | Folds the wrapped @f@.
--
-- Every method below is forwarded to @f@ rather than left to the class
-- defaults, which are derived from 'foldMap'/'foldr'. Forwarding keeps @f@'s
-- own implementations. For example, a container with O(1) 'length' or 'null'
-- stays O(1) when wrapped in 'Id'.
instance Foldable f => Foldable (Id f) where
  foldMap f (Id x) =
    foldMap f x
  foldr f z (Id x) =
    foldr f z x
  foldl f z (Id x) =
    foldl f z x
  null (Id x) =
    null x
  length (Id x) =
    length x
  elem e (Id x) =
    elem e x
  maximum (Id x) =
    maximum x
  minimum (Id x) =
    minimum x
  sum (Id x) =
    sum x
  product (Id x) =
    product x

-- | Non-empty fold, delegated to the wrapped @f@.
instance Foldable1 f => Foldable1 (Id f) where
  foldMap1 g (Id fa) = foldMap1 g fa

-- | Traverses the wrapped @f@ and re-wraps the result in 'Id'.
instance Traversable f => Traversable (Id f) where
  traverse g (Id fa) = fmap Id (traverse g fa)

-- | Non-empty traversal of the wrapped @f@, re-wrapping the result in 'Id'.
instance Traversable1 f => Traversable1 (Id f) where
  traverse1 g (Id fa) = fmap Id (traverse1 g fa)

-- | Lifts an 'IO' action into @f@, then wraps it in 'Id'.
instance MonadIO f => MonadIO (Id f) where
  liftIO io = Id (liftIO io)

-- | Lifting into 'Id' is just wrapping the underlying computation.
instance MonadTrans Id where
  lift m = Id m

-- | Semigroupoids counterpart of 'lift': wraps the underlying computation.
instance BindTrans Id where
  liftB b = Id b

-- | Error handling delegated to @f@; handler results are unwrapped from
-- 'Id' before being passed to @f@'s 'catchError'.
instance MonadError a f => MonadError a (Id f) where
  throwError err = Id (throwError err)
  catchError (Id action) handler = Id (catchError action recover)
    where
      recover err = case handler err of
        Id recovered -> recovered

-- | 'callCC' in @f@: the escape continuation handed to the body is
-- re-wrapped in 'Id', and the body's 'Id' result is unwrapped.
instance MonadCont f => MonadCont (Id f) where
  callCC body = Id (callCC inner)
    where
      inner exit = case body (\a -> Id (exit a)) of
        Id m -> m

-- | Reader capabilities delegated to @f@.
instance MonadReader a f => MonadReader a (Id f) where
  ask = Id ask
  local g (Id m) = Id (local g m)
  reader g = Id (reader g)

-- | Writer capabilities delegated to @f@.
instance MonadWriter a f => MonadWriter a (Id f) where
  writer = Id . writer
  tell w = Id (tell w)
  listen (Id m) = Id (listen m)
  pass (Id m) = Id (pass m)

-- | State capabilities delegated to @f@.
instance MonadState a f => MonadState a (Id f) where
  get = Id get
  put s = Id (put s)
  state g = Id (state g)

-- | Lifts 'MonadRWS' through 'Id'.
--
-- The reader (@r@), writer (@w@) and state (@s@) types are taken
-- independently from @f@. Using one shared type variable for all three
-- would, via the class's functional dependencies, reject every @f@ whose
-- three types differ. The component capabilities come from the
-- 'MonadReader', 'MonadWriter' and 'MonadState' instances for @'Id' f@.
instance MonadRWS r w s f => MonadRWS r w s (Id f)

-- | An 'Id' over 'Proxy': holds no value of the element type.
--
-- Eta-reduced so it can also be used unsaturated, e.g. as an argument of
-- kind @* -> *@; @Id0 a@ still means @Id Proxy a@.
type Id0 =
  Id Proxy

-- | An 'Id' over 'Maybe': zero or one element.
--
-- Eta-reduced so it can also be used unsaturated; @IdB a@ still means
-- @Id Maybe a@.
type IdB =
  Id Maybe

-- | An 'Id' over 'Identity': exactly one element.
--
-- Eta-reduced so it can also be used unsaturated; @Id1 a@ still means
-- @Id Identity a@.
type Id1 =
  Id Identity

-- | An 'Id' over @'Tagged' f@: exactly one element carrying the phantom
-- tag @f@.
--
-- Eta-reduced so it can also be used with only the tag applied;
-- @IdTagged f a@ still means @Id (Tagged f) a@.
type IdTagged f =
  Id (Tagged f)

{-#
  SPECIALIZE
  idTagged ::
    Iso
      (Id1 a)
      (Id1 a')
      (IdTagged f a)
      (IdTagged f a')
  #-}
-- | Re-expresses the payload of an 'Id' as a 'Tagged' value.
--
-- Forward: the wrapped @f a@ is unwrapped with '_Wrapped' (so @f a@ must be
-- 'Rewrapped') and the result is put in 'Tagged'.
-- Backward: the 'Tagged' payload is always rebuilt inside an 'Identity',
-- giving an 'Id1' whatever @f@ was.
--
-- NOTE(review): because the reverse direction is fixed to 'Identity', the
-- two directions only round-trip to the original type when @f@ is
-- 'Identity' (the specialised case above).
idTagged ::
  Rewrapped (f a) (f a) =>
  Iso
    (Id f a)
    (Id1 a')
    (Id (Tagged s) (Unwrapped (f a)))
    (Id (Tagged s') a')
idTagged =
  iso
    (\(Id wrapped) -> Id (Tagged (view _Wrapped wrapped)))
    (\(Id tagged) -> Id (Identity (untag tagged)))

{-#
  SPECIALIZE
  just ::
    Prism
      (IdB a)
      (IdB a')
      (Id1 a)
      (Id1 a')
  #-}
-- | A prism from a wrapped 'Maybe' onto its 'Just' payload, re-wrapped twice.
--
-- From the outside in: '_Wrapped' exposes the 'Maybe' inside @s@, '_Just'
-- matches the 'Just' case (the prism fails on 'Nothing'), and the two
-- '_Unwrapped's rebuild the element first as @'Unwrapped' a@ and then as @a@.
-- At the specialised type, @Id (Just x) :: IdB x@ is focused as
-- @Id (Identity x) :: Id1 x@.
--
-- The equality constraints tie the 'Maybe' inside @s@/@t@ to the
-- doubly-unwrapped @a@/@b@. The 'Rewrapped' pairs are what '_Wrapped' and
-- '_Unwrapped' require in both directions at each layer.
just ::
  (
    Unwrapped s ~ Maybe (Unwrapped (Unwrapped a)),
    Unwrapped t ~ Maybe (Unwrapped (Unwrapped b)),
    Rewrapped s t, Rewrapped t s, Rewrapped b a,
    Rewrapped a b, Rewrapped (Unwrapped b) (Unwrapped a),
    Rewrapped (Unwrapped a) (Unwrapped b)
  ) =>
  Prism s t a b
just =
  _Wrapped . _Just . _Unwrapped . _Unwrapped

{-#
  SPECIALIZE
  rejust ::
    Getter
      (Id1 a)
      (IdB a)
  #-}
-- | 'just' run backwards as a 'Getter': takes a doubly-wrapped value and
-- builds the corresponding wrapped 'Just'. At the specialised type,
-- @Id (Identity x) :: Id1 x@ becomes @Id (Just x) :: IdB x@.
--
-- The constraints are those of 'just' with @s = t@ and @a = b@, as
-- required by 're'.
rejust ::
  (
    Unwrapped t ~ Maybe (Unwrapped (Unwrapped b)), Rewrapped t t,
    Rewrapped b b, Rewrapped (Unwrapped b) (Unwrapped b)
  ) =>
  Getter b t
rejust =
  re just

{-#
  SPECIALIZE
  nothing ::
    Prism
      (IdB a)
      (IdB a)
      (Id1 ())
      (Id1 ())
  #-}
-- | A prism from a wrapped 'Maybe' onto its 'Nothing' case.
--
-- '_Wrapped' exposes the 'Maybe', '_Nothing' matches 'Nothing' (focusing
-- @()@), and the two '_Unwrapped's rebuild that unit as a doubly-wrapped
-- value. At the specialised type this is @Id (Identity ()) :: Id1 ()@.
--
-- The first two constraints force @s@ and @t@ to wrap the same @Maybe x@,
-- so this prism cannot change the element type.
nothing ::
  (
    Unwrapped t ~ Maybe x, Unwrapped s ~ Maybe x,
    Unwrapped (Unwrapped a) ~ (), Unwrapped (Unwrapped b) ~ (),
    Rewrapped s t, Rewrapped t s,
    Rewrapped b a, Rewrapped a b,
    Rewrapped (Unwrapped b) (Unwrapped a),
    Rewrapped (Unwrapped a) (Unwrapped b)
  ) =>
  Prism s t a b
nothing =
  _Wrapped . _Nothing . _Unwrapped . _Unwrapped

{-#
  SPECIALIZE
  renothing ::
    Getter
      (Id1 ())
      (IdB a)
  #-}
-- | 'nothing' run backwards as a 'Getter': from a doubly-wrapped unit it
-- builds a wrapped 'Nothing'. At the specialised type, any @Id1 ()@ becomes
-- @Id Nothing :: IdB a@.
renothing ::
  (
    Unwrapped t ~ Maybe x, Unwrapped (Unwrapped a) ~ (),
    Rewrapped t t, Rewrapped a a,
    Rewrapped (Unwrapped a) (Unwrapped a)
  ) =>
  Getter a t
renothing =
  re nothing

{-#
  SPECIALIZE
  id1 ::
    Iso
    (Id1 a)
    (Id1 a')
    a
    a'
  #-}
-- | An iso that strips two layers of wrapping ('_Wrapped' twice). At the
-- specialised type it converts between @Id1 a@ (that is,
-- @Id (Identity a)@) and a bare @a@.
--
-- Both layers must be 'Rewrapped' in each direction, which is what the
-- four constraints state.
id1 ::
  (
    Rewrapped s t, Rewrapped t s,
    Rewrapped (Unwrapped s) (Unwrapped t),
    Rewrapped (Unwrapped t) (Unwrapped s)
  ) =>
  Iso s t (Unwrapped (Unwrapped s)) (Unwrapped (Unwrapped t))
id1 =
  _Wrapped . _Wrapped
