{-# LANGUAGE ConstraintKinds         #-}
{-# LANGUAGE DataKinds               #-}
{-# LANGUAGE DefaultSignatures       #-}
{-# LANGUAGE FlexibleContexts        #-}
{-# LANGUAGE FlexibleInstances       #-}
{-# LANGUAGE KindSignatures          #-}
{-# LANGUAGE MultiParamTypeClasses   #-}
{-# LANGUAGE PolyKinds               #-}
{-# LANGUAGE ScopedTypeVariables     #-}
{-# LANGUAGE TypeApplications        #-}
{-# LANGUAGE TypeFamilies            #-}
{-# LANGUAGE TypeOperators           #-}
{-# LANGUAGE UndecidableInstances    #-}

-- The 'HasNormalForm' constraint on 'normalize' and 'denormalize' is
-- redundant as far as GHC is concerned (the implementation is just
-- 'unsafeCoerce'), but it is essential for the type safety of both functions.
{-# OPTIONS_GHC -Wno-redundant-constraints #-}

module Data.Record.Generic.Transform (
    -- * Interpretation function
    Interpreted
  , Interpret(..)
    -- ** Working with the 'Interpreted' newtype wrapper
  , liftInterpreted
  , liftInterpretedA2
    -- * Normal form
    -- ** Existence
  , HasNormalForm
  , InterpretTo
  , IfEqual
    -- ** Construction
  , normalize
  , denormalize
    -- ** Specialized forms for the common case of a single type argument
  , Uninterpreted
  , DefaultInterpretation
  , normalize1
  , denormalize1
    -- ** Generalization of the default interpretation
  , StandardInterpretation(..)
  , toStandardInterpretation
  , fromStandardInterpretation
  ) where

import Data.Coerce
import Data.Kind
import Data.Proxy
import Data.SOP.BasicFunctors
import GHC.TypeLits
import Unsafe.Coerce (unsafeCoerce)

import Data.Record.Generic

{-------------------------------------------------------------------------------
  Interpretation function
-------------------------------------------------------------------------------}

-- | Interpretation function: @Interpreted d x@ is the type that @x@ stands
-- for in the interpretation domain @d@.
--
-- This is an open type family; each domain supplies its own instances (see
-- 'DefaultInterpretation' for an example).
type family Interpreted (d :: dom) (x :: Type) :: Type

-- | Newtype wrapper around 'Interpreted'.
--
-- A type family cannot be partially applied, but this newtype can:
-- @'Interpret' d@ is usable as the functor argument of 'Rep', as in
-- @'Rep' ('Interpret' d) y@.
newtype Interpret d x = Interpret (Interpreted d x)

{-------------------------------------------------------------------------------
  Working with the 'Interpreted' newtype wrapper
-------------------------------------------------------------------------------}

-- | Lift a function on interpreted values to a function on the 'Interpret'
-- newtype wrapper.
--
-- The argument is unwrapped, the function is applied, and the result is
-- wrapped again.
liftInterpreted ::
      (Interpreted dx x -> Interpreted dy y)
   -> (Interpret   dx x -> Interpret   dy y)
liftInterpreted f (Interpret x) = Interpret (f x)

-- | Lift an effectful binary function on interpreted values to a function on
-- the 'Interpret' newtype wrapper.
--
-- Both arguments are unwrapped and the function is applied. The result is
-- then wrapped again inside the context @m@.
liftInterpretedA2 ::
      Applicative m
   => (Interpreted dx x -> Interpreted dy y -> m (Interpreted dz z))
   -> (Interpret   dx x -> Interpret   dy y -> m (Interpret   dz z))
liftInterpretedA2 f (Interpret x) (Interpret y) = Interpret <$> f x y

{-------------------------------------------------------------------------------
  Normal forms
-------------------------------------------------------------------------------}

-- | @HasNormalForm d x y@ holds when @y@ is a normal form of @x@ in the
-- interpretation domain @d@.
--
-- It is checked field by field by 'InterpretTo' on the metadata of both
-- records (presumably a list of (field name, field type) pairs, given the
-- patterns 'InterpretTo' matches on).
type HasNormalForm d x y = InterpretTo d (MetadataOf x) (MetadataOf y)

-- | Field-by-field check behind 'HasNormalForm'.
--
-- Both lists must have the same length, and corresponding entries must share
-- the same field name. Each field type on the left must equal the
-- 'Interpreted' (in @d@) version of the field type on the right; mismatches
-- are reported through 'IfEqual'.
type family InterpretTo d xs ys :: Constraint where
  InterpretTo _ '[]             '[]             = ()
  InterpretTo d ('(f, x) ': xs) ('(f, y) ': ys) = IfEqual x (Interpreted d y)
                                                    (InterpretTo d xs ys)
  -- Closed family: this equation is only chosen once the arguments are known
  -- to be apart from both equations above. That happens when the field names
  -- differ or one list is longer than the other. Without it, the family would
  -- get stuck and users would see an opaque unsolved constraint.
  InterpretTo _ xs              ys              = TypeError (
          'Text "Record fields do not match the fields of the normal form:"
    ':$$: 'Text "  remaining record fields:      " ':<>: 'ShowType xs
    ':$$: 'Text "  remaining normal form fields: " ':<>: 'ShowType ys
    )

-- | Type-level equality check with a readable error message.
--
-- @IfEqual expected actual r@ reduces to @r@ when @expected@ and @actual@
-- are the same type. Otherwise it reduces to a custom 'TypeError' that
-- names both types.
type family IfEqual x y (r :: k) :: k where
  -- Non-linear pattern: matches only when both arguments are the same type.
  IfEqual actual   actual k = k
  -- Closed family: only reached once the types are known to differ.
  IfEqual expected actual k = TypeError (
          'Text "Expected "
    ':<>: 'ShowType expected
    ':<>: 'Text " but got "
    ':<>: 'ShowType actual
    )

-- | Construct the normal form of a record.
--
-- Reinterprets the generic representation of @x@, whose fields are wrapped
-- in 'I', as the representation of @y@ with every field wrapped in
-- @'Interpret' d@. The two proxies fix @d@ and @y@; neither can be inferred
-- from the argument.
--
-- The conversion is a runtime no-op implemented with 'unsafeCoerce'. Its
-- safety rests entirely on the 'HasNormalForm' constraint. That constraint
-- requires both records to have matching fields, in the same order, with each
-- field type of @x@ equal to the 'Interpreted' type of the corresponding
-- field of @y@. 'Interpret' is a newtype (as is 'I'), so the wrappers add no
-- runtime structure.
normalize ::
     HasNormalForm d x y
  => Proxy d
  -> Proxy y
  -> Rep I x -> Rep (Interpret d) y
normalize _ _ = unsafeCoerce

-- | Inverse of 'normalize'.
--
-- Reinterprets the normal form @y@, whose fields are wrapped in
-- @'Interpret' d@, as the representation of @x@ with fields wrapped in 'I'.
-- The two proxies fix @d@ and @y@.
--
-- Like 'normalize', this is a runtime no-op implemented with 'unsafeCoerce'.
-- The 'HasNormalForm' constraint is what guarantees that the field types line
-- up.
denormalize ::
     HasNormalForm d x y
  => Proxy d
  -> Proxy y
  -> Rep (Interpret d) y -> Rep I x
denormalize _ _ = unsafeCoerce

{-------------------------------------------------------------------------------
  Specialized forms for the common case of a single type argument

  The tests in "Test.Record.Generic.Sanity.Transform" show an example with
  two arguments.
-------------------------------------------------------------------------------}

-- | Marker for a type argument that has not been interpreted yet.
--
-- It has no constructors and is only used at the type level, e.g. as the
-- argument in the normal form @x 'Uninterpreted'@.
data Uninterpreted x

-- | Interpretation domain in which @'Uninterpreted' x@ is interpreted as
-- @f x@ (see the 'Interpreted' instance for it in this module).
data DefaultInterpretation (f :: Type -> Type)

-- Under the default interpretation, an uninterpreted @x@ simply becomes @f x@.
type instance Interpreted (DefaultInterpretation f) (Uninterpreted x) = f x

-- | 'normalize' specialised to records with a single type argument.
--
-- The record @x f@ is normalised to @x 'Uninterpreted'@ in the
-- interpretation domain @d f@. Only the proxy for @d@ is required; the
-- type variables are in scope (via the explicit @forall@) to build the
-- proxies that 'normalize' needs.
normalize1 :: forall d f x.
     HasNormalForm (d f) (x f) (x Uninterpreted)
  => Proxy d
  -> Rep I (x f) -> Rep (Interpret (d f)) (x Uninterpreted)
normalize1 _ = normalize (Proxy @(d f)) (Proxy @(x Uninterpreted))

-- | 'denormalize' specialised to records with a single type argument.
--
-- Converts the normal form @x 'Uninterpreted'@ in the domain @d f@ back to
-- the representation of @x f@. This is the inverse of 'normalize1'.
denormalize1 :: forall d f x.
     HasNormalForm (d f) (x f) (x Uninterpreted)
  => Proxy d
  -> Rep (Interpret (d f)) (x Uninterpreted) -> Rep I (x f)
denormalize1 _ = denormalize (Proxy @(d f)) (Proxy @(x Uninterpreted))

{-------------------------------------------------------------------------------
  Generalization of the default interpretation
-------------------------------------------------------------------------------}

-- | Interpretation domains @d@ in which an interpreted 'Uninterpreted' field
-- can be converted to and from @f x@.
--
-- 'standardInterpretation' returns a pair of conversions: the first maps
-- the interpreted type to @f x@, and the second maps @f x@ back. They are
-- presumably intended to be mutually inverse, as the default 'coerce' pair
-- is, but this is not checked.
--
-- The default implementation uses 'coerce' in both directions. It is
-- available whenever the interpreted type is 'Coercible' to @f x@.
class StandardInterpretation d f where
  standardInterpretation ::
       Proxy d
    -> ( Interpreted (d f) (Uninterpreted x) -> f x
       , f x -> Interpreted (d f) (Uninterpreted x)
       )

  default standardInterpretation ::
       Coercible (Interpreted (d f) (Uninterpreted x)) (f x)
    => Proxy d
    -> ( Interpreted (d f) (Uninterpreted x) -> f x
       , f x -> Interpreted (d f) (Uninterpreted x)
       )
  standardInterpretation _ = (coerce, coerce)

-- | The default interpretation is the standard one: the family instance
-- reduces @'Interpreted' ('DefaultInterpretation' f) ('Uninterpreted' x)@
-- to @f x@, so both conversions are the identity.
instance StandardInterpretation DefaultInterpretation f where
  standardInterpretation _ = (id, id)

-- | Wrap a value of type @f x@ as an interpreted 'Uninterpreted' field of
-- domain @d f@.
--
-- Uses the second component (@f x -> interpreted@) of
-- 'standardInterpretation'.
toStandardInterpretation :: forall d f x.
     StandardInterpretation d f
  => Proxy d
  -> f x -> Interpret (d f) (Uninterpreted x)
toStandardInterpretation d fx = Interpret $
    snd (standardInterpretation d) fx

-- | Unwrap an interpreted 'Uninterpreted' field of domain @d f@ to a value of
-- type @f x@.
--
-- Uses the first component (@interpreted -> f x@) of
-- 'standardInterpretation'.
fromStandardInterpretation :: forall d f x.
     StandardInterpretation d f
  => Proxy d
  -> Interpret (d f) (Uninterpreted x) -> f x
fromStandardInterpretation d (Interpret fx) =
    fst (standardInterpretation d) fx