{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

{- |
Module      :  Data.Aeson.Schema.Utils.Sum
Maintainer  :  Brandon Chinn <brandon@leapyear.io>
Stability   :  experimental
Portability :  portable

The 'SumType' data type that represents a sum type consisting of types
specified in a type-level list.
-}
module Data.Aeson.Schema.Utils.Sum (
  SumType (..),
  fromSumType,
) where

import Control.Applicative ((<|>))
import Data.Aeson (FromJSON (..), ToJSON (..))
import Data.Kind (Constraint, Type)
import Data.Proxy (Proxy (..))
import GHC.TypeLits (ErrorMessage (..), Nat, TypeError, type (-))

{- | Represents a sum type.

 Loads the first type that successfully parses the JSON value.

 Example:

 @
 data Owl = Owl
 data Cat = Cat
 data Toad = Toad
 type Animal = SumType '[Owl, Cat, Toad]

 Here Owl                         :: Animal
 There (Here Cat)                 :: Animal
 There (There (Here Toad))        :: Animal

 {\- Fails at compile-time
 Here True                        :: Animal
 Here Cat                         :: Animal
 There (Here Owl)                 :: Animal
 There (There (There (Here Owl))) :: Animal
 -\}
 @
-}
data SumType (types :: [Type]) where
  -- | The value belongs to the first type in the list.
  Here :: forall x xs. x -> SumType (x ': xs)
  -- | The value belongs to one of the types after the first; the wrapped
  -- 'SumType' is indexed by the tail of the list.
  There :: forall x xs. SumType xs -> SumType (x ': xs)

-- Inductive case: a non-empty sum is showable when its head type and the sum
-- over the remaining types are both showable.
deriving instance (Show x, Show (SumType xs)) => Show (SumType (x ': xs))

-- Base case: neither 'Here' nor 'There' can build a @SumType '[]@, so there
-- is no constructor to match on and the empty case is exhaustive.
instance Show (SumType '[]) where
  show = \case {}

-- Inductive case: equality is derived structurally from the head type and the
-- sum over the remaining types.
deriving instance (Eq x, Eq (SumType xs)) => Eq (SumType (x ': xs))

-- Base case: neither 'Here' nor 'There' can build a @SumType '[]@; the
-- instance exists so the recursive constraint above can bottom out.
instance Eq (SumType '[]) where
  _ == _ = True

-- Inductive case: ordering is derived structurally from the head type and the
-- sum over the remaining types.
deriving instance (Ord x, Ord (SumType xs)) => Ord (SumType (x ': xs))

-- Base case: neither 'Here' nor 'There' can build a @SumType '[]@; the
-- instance exists so the recursive constraint above can bottom out.
instance Ord (SumType '[]) where
  compare _ _ = EQ

-- Tries to parse the value as the head type first; if that parser fails,
-- falls back to parsing it as a sum over the remaining types. The first
-- branch that parses wins, so the order of the type-level list matters.
instance (FromJSON x, FromJSON (SumType xs)) => FromJSON (SumType (x ': xs)) where
  parseJSON v = (Here <$> parseJSON v) <|> (There <$> parseJSON v)

-- Base case of the parser chain: no candidate types are left, so parsing
-- fails regardless of the input value.
instance FromJSON (SumType '[]) where
  parseJSON _ = fail "Could not parse sum type"

-- Encodes whichever branch is populated using that branch's own 'ToJSON'
-- instance; the position within the sum is not recorded in the output.
instance (ToJSON x, ToJSON (SumType xs)) => ToJSON (SumType (x ': xs)) where
  toJSON = \case
    Here x -> toJSON x
    There xs -> toJSON xs

-- Base case: neither 'Here' nor 'There' can build a @SumType '[]@, so there
-- is no constructor to match on and the empty case is exhaustive.
instance ToJSON (SumType '[]) where
  toJSON = \case {}

{- Extracting sum type branches -}

-- Extracts the branch at type-level index @n@ of a sum over @types@ as a value
-- of type @x@. The equality constraint on the method ties @x@ to the element
-- that 'GetIndex' finds at position @n@ in @types@.
class FromSumType (n :: Nat) (types :: [Type]) (x :: Type) where
  fromSumType' :: 'Just x ~ GetIndex n types => proxy1 n -> SumType types -> Maybe x

-- Index 0 selects the head of the list: the value is returned only when the
-- sum holds its first branch.
instance {-# OVERLAPPING #-} FromSumType 0 (x ': xs) x where
  fromSumType' _ = \case
    Here x -> Just x
    There _ -> Nothing

-- Any other index: the head branch can never match, so skip past it and look
-- up index @n - 1@ in the sum over the remaining types.
instance
  {-# OVERLAPPABLE #-}
  ( FromSumType (n - 1) xs x
  , 'Just x ~ GetIndex (n - 1) xs
  ) =>
  FromSumType n (_x ': xs) x
  where
  fromSumType' _ = \case
    Here _ -> Nothing
    -- The proxy for the decremented index is built with a type application
    -- because no value of that type is available to pass along.
    There xs -> fromSumType' (Proxy @(n - 1)) xs

{- | Extract a value from a 'SumType'

 Example:

 @
 type Animal = SumType '[Owl, Cat, Toad]
 let someAnimal = ... :: Animal

 fromSumType (Proxy :: Proxy 0) someAnimal :: Maybe Owl
 fromSumType (Proxy :: Proxy 1) someAnimal :: Maybe Cat
 fromSumType (Proxy :: Proxy 2) someAnimal :: Maybe Toad

 -- Compile-time error
 -- fromSumType (Proxy :: Proxy 3) someAnimal
 @
-}
fromSumType ::
  ( IsInRange n types -- produces the custom "Index ... does not exist" type error
  , 'Just result ~ GetIndex n types -- result is the list element at index n
  , FromSumType n types result
  ) =>
  proxy n ->
  SumType types ->
  Maybe result
-- The class method does the extraction; this wrapper only adds the
-- 'IsInRange' check on top of the method's own constraints.
fromSumType = fromSumType'

{- Helpers -}

-- Constraint that is satisfied when @idx@ is a valid position in @tys@ and
-- otherwise reduces to a custom type error. The message is built here, from
-- the index and list the caller supplied, and handed to 'IsInRange'' which
-- carries it unchanged through its recursion.
type family IsInRange (idx :: Nat) (tys :: [Type]) :: Constraint where
  IsInRange idx tys =
    IsInRange'
      ( TypeError
          ( 'Text "Index "
              ':<>: 'ShowType idx
              ':<>: 'Text " does not exist in list: "
              ':<>: 'ShowType tys
          )
      )
      idx
      tys

-- Worker for 'IsInRange': steps down the index and the list together, passing
-- the caller-supplied error constraint through unchanged. Equations are tried
-- top to bottom, so the final equation only applies to a non-zero index.
type family IsInRange' typeErr (n :: Nat) (xs :: [Type]) :: Constraint where
  -- The list ran out before the index reached 0: out of range.
  IsInRange' typeErr _ '[] = typeErr
  -- The index reached 0 while elements remain: in range, nothing to require.
  IsInRange' _ 0 (_ ': _) = ()
  -- Otherwise drop the head and continue with the index decremented.
  IsInRange' typeErr n (_ ': xs) = IsInRange' typeErr (n - 1) xs

-- Type-level list lookup: yields @'Just@ the element at index @n@, or
-- @'Nothing@ when the list is exhausted first. Equations are tried top to
-- bottom, so the final equation only applies to a non-zero index.
type family GetIndex (n :: Nat) (types :: [Type]) :: Maybe Type where
  -- Index 0 of a non-empty list is its head.
  GetIndex 0 (x ': xs) = 'Just x
  -- Nothing can be found in an empty list, whatever the index.
  GetIndex _ '[] = 'Nothing
  -- Otherwise drop the head and look up the decremented index in the tail.
  GetIndex n (_ ': xs) = GetIndex (n - 1) xs