{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}

-- | 'EGADT': an EADT with an additional type parameter.
--
-- A value is a 'VariantF' over a type-level list of GADT-style constructors.
-- Every constructor receives the 'EGADT' itself as its recursive parameter
-- and shares a common type index. Values are built and matched with the
-- bidirectional pattern 'VF'.
module Haskus.Utils.EGADT where

import Unsafe.Coerce
import Haskus.Utils.Monad
import Haskus.Utils.Variant
import Haskus.Utils.VariantF
import Haskus.Utils.Types

-- $setup
-- >>> :set -XDataKinds
-- >>> :set -XTypeApplications
-- >>> :set -XTypeOperators
-- >>> :set -XFlexibleContexts
-- >>> :set -XTypeFamilies
-- >>> :set -XPatternSynonyms
-- >>> :set -XDeriveFunctor
-- >>> :set -XGADTs
-- >>> :set -XPolyKinds
-- >>> :set -XPartialTypeSignatures
-- >>>
-- >>> :{
-- >>> data LamF (ast :: Type -> Type) t where
-- >>>    LamF :: ( ast a -> ast b ) -> LamF ast ( a -> b )
-- >>>
-- >>> data AppF ast t where
-- >>>    AppF :: ast ( a -> b ) -> ast a -> AppF ast b
-- >>>
-- >>> data VarF ast t where
-- >>>    VarF :: String -> VarF ast Int
-- >>>
-- >>> type AST a = EGADT '[LamF,AppF,VarF] a
-- >>>
-- >>> :}
--
-- >>> let y = VF @(AST Int) (VarF "a")
-- >>> :t y
-- y :: EGADT '[LamF, AppF, VarF] Int
--
-- >>> :{
-- >>> case y of
-- >>>    VF (VarF x) -> print x
-- >>>    _ -> putStrLn "Not a VarF"
-- >>> :}
-- "a"
--
-- >>> :{
-- >>> f :: AST Int -> AST Int
-- >>> f (VF (VarF x)) = VF (VarF "zz")
-- >>> f _ = error "Unhandled case"
-- >>> :}
--
-- >>> let z = VF (AppF (VF (LamF f)) (VF (VarF "a")))
-- >>> :t z
-- z :: EGADT '[LamF, AppF, VarF] Int
--

-- | An EADT with an additional type parameter
--
-- @fs@ is the type-level list of constructors. Each constructor is given
-- @EGADT fs@ itself as its recursive @ast@ parameter (see 'HVariantF').
-- @t@ is the type index shared by all constructors, e.g. @Int@ in the
-- @AST Int@ examples of the @$setup@ section.
newtype EGADT fs t = EGADT (HVariantF fs (EGADT fs) t)

-- | One layer of an 'EGADT': a 'VariantF' whose alternatives are the
-- constructors of @fs@, all applied to the same sub-term type @ast@ and
-- indexed by @t@.
--
-- Each element of @fs@ has kind @(k -> Type) -> (k -> Type)@. It takes the
-- type used for sub-terms and yields a type indexed by @k@.
-- NOTE(review): presumably 'ApplyAll' applies every element of @fs@ to
-- @ast@; confirm against its definition.
newtype HVariantF (fs :: [ (k -> Type) -> ( k -> Type) ]) (ast :: k -> Type) (t :: k)
  = HVariantF (VariantF (ApplyAll ast fs) t)

-- | Inject a value built with the @i@-th constructor of @fs@ into the
-- underlying 'VariantF', using the type-level natural @i@ as the runtime tag.
--
-- Meant to be used with type applications, e.g. @toHVariantAt \@i \@fs@
-- (as 'VF' does).
toHVariantAt :: forall i fs ast a .
  KnownNat i => (Index i fs) ast a -> VariantF (ApplyAll ast fs) a
{-# INLINABLE toHVariantAt #-}
-- The payload is stored with 'unsafeCoerce', tagged with @i@. This is only
-- sound if @(Index i fs) ast a@ is the type the variant expects at tag @i@.
-- NOTE(review): presumably that is the @i@-th element of @ApplyAll ast fs@
-- applied to @a@; confirm against 'Index' / 'ApplyAll'.
toHVariantAt a = VariantF (Variant (natValue' @i) (unsafeCoerce a))

-- | Counterpart of 'toHVariantAt'. Returns the payload as a value of the
-- @i@-th constructor of @fs@ when the variant's runtime tag equals @i@, and
-- 'Nothing' otherwise.
--
-- Meant to be used with type applications, e.g. @fromHVariantAt \@i \@fs@.
fromHVariantAt :: forall i fs ast a .
  KnownNat i => VariantF (ApplyAll ast fs) a -> Maybe ((Index i fs) ast a)
{-# INLINABLE fromHVariantAt #-}
fromHVariantAt (VariantF (Variant t a)) = do
  -- Bail out (Nothing) unless the stored tag is the requested index.
  guard (t == natValue' @i)
  -- The tag matched, so the payload was stored at index @i@. Coercing it
  -- back to the @i@-th constructor's type undoes the coercion made when it
  -- was injected (e.g. by 'toHVariantAt').
  return (unsafeCoerce a)

-- The base functor of an 'EGADT' is its 'HVariantF' layer.
type instance HBase (EGADT xs) = HVariantF xs

-- | 'hproject' peels off the 'EGADT' newtype, exposing one 'HVariantF' layer.
instance HFunctor (HVariantF xs) => HRecursive (EGADT xs) where
  hproject (EGADT a) = a

-- | 'hembed' wraps one 'HVariantF' layer back into an 'EGADT'.
instance HFunctor (HVariantF xs) => HCorecursive (EGADT xs) where
  hembed = EGADT

-- NOTE(review): the next line is damaged in this copy of the file and will
-- almost certainly not compile as written.
--
-- Presumably everything between a "<" (in an operator name starting with
-- ":<") and a later ">" (probably the "=>" of the type signature of 'VF')
-- was stripped. That would have removed:
--   * the rest of this (constraint) type family declaration, and
--   * the head of the signature of 'VF', which must bind the scoped type
--     variables @f@ and @fs@ used in the pattern body below.
-- Only the tail "f (EGADT fs) a -> EGADT fs a" survives.
--
-- Restore the original text from upstream rather than guessing it.
type family f : f (EGADT fs) a -> EGADT fs a

-- | Bidirectional pattern for the constructor @f@ of an @EGADT fs a@.
--
-- * Matching runs 'fromHVariantAt' at tag @IndexOf f fs@ and succeeds,
--   binding @x@, only when it returns 'Just'.
-- * Constructing injects @x@ at that same tag with 'toHVariantAt'.
--
-- NOTE(review): @IndexOf f fs@ is presumably the position of @f@ in @fs@;
-- confirm against its definition. See the @$setup@ section for examples.
pattern VF x <- ( ( \ ( EGADT (HVariantF v) ) -> fromHVariantAt @(IndexOf f fs) @fs v ) -> Just x )
  where
    VF x = EGADT (HVariantF (toHVariantAt @(IndexOf f fs) @fs x))