{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE IncoherentInstances #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeFamilies #-}

module Data.TypeNat.Nat (

    Nat(..)
  , IsNat(..)

  , LTE(..)
  , StrongLTE

  , Zero
  , One
  , Two
  , Three
  , Four
  , Five
  , Six
  , Seven
  , Eight
  , Nine
  , Ten

  ) where

import Data.Proxy
import GHC.Exts (Constraint)

-- | Unary (Peano) natural numbers. With @DataKinds@ this type is also
--   promoted to a kind, so @'Z@ and @'S n@ can index other types.
data Nat
  = Z      -- ^ zero
  | S Nat  -- ^ successor of another natural

-- | Structural equality: two naturals are equal exactly when they are built
--   from the same number of 'S' constructors.
instance Eq Nat where
  Z     == Z     = True
  (S n) == (S m) = n == m
  _     == _     = False

-- | Numeric ordering: strip matching 'S' constructors from both sides until
--   one (or both) reaches 'Z'.
instance Ord Nat where
  compare Z     Z     = EQ
  compare (S n) (S m) = compare n m
  compare (S _) Z     = GT
  compare Z     (S _) = LT

-- Type-level numerals 0 through 10 (promoted 'Nat' values); each one is the
-- promoted successor of the numeral before it.
type Zero  = 'Z
type One   = 'S Zero
type Two   = 'S One
type Three = 'S Two
type Four  = 'S Three
type Five  = 'S Four
type Six   = 'S Five
type Seven = 'S Six
type Eight = 'S Seven
type Nine  = 'S Eight
type Ten   = 'S Nine

-- | Evidence that a type-level 'Nat' is built from 'Z' and 'S', which lets a
--   value of type @a n@ be computed by recursion on the type @n@ alone.
class IsNat (n :: Nat) where
  natRecursion
    :: (forall m . b -> a m -> a (S m))  -- ^ step: combine a seed with the result at @m@
    -> (b -> a Z)                        -- ^ base case, given the seed
    -> (b -> b)                          -- ^ transformation applied to the seed when recursing
    -> b                                 -- ^ initial seed
    -> a n

-- | Base case: at 'Z' the result is the base-case function applied to the
--   seed; the step and seed-transformer arguments are unused.
instance IsNat Z where
  natRecursion _ ifZ _ = ifZ

-- | Successor case: recurse on @n@ with the seed transformed by @reduce@,
--   then combine the current (untransformed) seed with that result via @ifS@.
instance IsNat n => IsNat (S n) where
  natRecursion ifS ifZ reduce x =
    ifS x (natRecursion ifS ifZ reduce (reduce x))

-- | A constraint which includes @LTE k m@ for every @k <= n@.
--   It unfolds by recursion on @n@: @StrongLTE ('S n) m@ is
--   @(LTE ('S n) m, StrongLTE n m)@, bottoming out at @LTE 'Z m@.
type family StrongLTE (n :: Nat) (m :: Nat) :: Constraint where
  StrongLTE Z m = LTE Z m
  StrongLTE (S n) m = (LTE (S n) m, StrongLTE n m)

-- | Evidence that the type-level natural @n@ is less than or equal to @m@.
--
--   Its methods perform type-directed computation on 'Nat'-indexed
--   datatypes: 'lteInduction' climbs from index @n@ up to index @m@, and
--   'lteRecursion' descends from index @m@ down to index @n@.
class LTE (n :: Nat) (m :: Nat) where
  -- | Turn a @d n@ into a @d m@ using an upward step function.
  lteInduction
    :: StrongLTE m l
    => Proxy l
    -- ^ Fixes the bound @l@ that the step function's constraint refers to.
    -> (forall k . LTE (S k) l => d k -> d (S k))
    -- ^ Upward step. The parameter @l@ is fixed by any call to
    --   'lteInduction', but because of the @StrongLTE m l@ constraint we
    --   have @LTE j l@ for every @j <= m@. That is what makes the
    --   nontrivial case of the
    --     @LTE p q => LTE p (S q)@
    --   instance implementable: it uses this function to get @x :: d p@ and
    --   then again to get @f x :: d (S p)@. This works so long as both @p@
    --   and @S p@ are less than or equal to @l@.
    -> d n
    -- ^ Starting value at index @n@.
    -> d m
  -- | Turn a @d m@ into a @d n@ using a downward step function.
  lteRecursion
    :: (forall k . LTE n k => d (S k) -> d k)
    -- ^ Downward step from index @S k@ to index @k@.
    -> d m
    -- ^ Starting value at index @m@.
    -> d n

-- | Reflexivity: every 'Nat' is less than or equal to itself. Both
--   traversals return their input unchanged; the step functions are unused.
instance LTE n n where
  lteInduction _ _ x = x
  lteRecursion _ x = x

-- | If @n <= m@ then @n <= S m@.
--
--   'lteInduction' first climbs from @n@ to @m@ recursively and then takes
--   one more step with @f@ to reach @S m@. Both are justified by the
--   @StrongLTE (S m) l@ constraint, which unfolds to
--   @(LTE (S m) l, StrongLTE m l)@.
--
--   'lteRecursion' takes one step down with @f@ (from @S m@ to @m@, which
--   needs the @LTE n m@ context) and then continues recursively from @m@.
instance LTE n m => LTE n (S m) where
  -- The pattern signatures name @l@ and @d@ (ScopedTypeVariables), mirroring
  -- the class signature.
  lteInduction (proxy :: Proxy l) f (x :: d n) = f (lteInduction proxy f x)
  lteRecursion f x = lteRecursion f (f x)