{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Presburger #-}

module Data.Type.Natural.Core
  ( SNat (.., Zero, Succ),
    ZeroOrSucc (..),
    viewNat,
    sNat,
    withKnownNat,
    (%+),
    (%-),
    (%*),
    (%^),
    sDiv,
    sMod,
    sLog2,
    (%<=?),
    sCmpNat,
    sCompare,
    Succ,
    S,
    sSucc,
    sS,
    Pred,
    sPred,
    Zero,
    One,
    sZero,
    sOne,
    Equality (..),
    type (===),
    (%~),
    sFlipOrdering,
    FlipOrdering,
    SOrdering (..),
    SBool (..),
    -- Re-exports
    module GHC.TypeNats,
  )
where

import Data.Coerce (coerce)
import Data.Proxy (Proxy)
import Data.Type.Equality
  ( TestEquality (..),
    gcastWith,
    type (:~:) (..),
    type (==),
  )
import Data.Type.Natural.Utils
import GHC.Exts (Proxy#, proxy#)
import GHC.TypeNats
import Math.NumberTheory.Logarithms (naturalLog2)
import Numeric.Natural (Natural)
import Proof.Propositional (Empty)
import Type.Reflection (Typeable)
import Unsafe.Coerce (unsafeCoerce)

-- | A singleton for type-level naturals.
--
-- The runtime payload is a 'Natural' that is expected to equal the
-- type-level index @n@. The constructor itself does not check this (and it
-- is exported), so code that applies 'SNat' directly can break the
-- invariant that the @unsafeCoerce@-based functions in this module rely on.
--
-- The instances are newtype-derived from 'Natural', so e.g. 'show' renders
-- just the number.
newtype SNat (n :: Nat) = SNat Natural
  deriving newtype (Show, Eq, Ord)

-- | Bring a 'KnownNat' instance for @n@ into scope from its singleton.
--
-- The payload is reified with 'someNatVal', which yields a 'KnownNat'
-- dictionary for a fresh existential @m@. The equality @n :~: m@ is then
-- asserted with 'unsafeCoerce' and used to discharge @KnownNat n@ for
-- @act@. This is sound only if the singleton's payload really is @n@.
withKnownNat :: forall n r. SNat n -> (KnownNat n => r) -> r
withKnownNat (SNat nat) act =
  case someNatVal nat of
    SomeNat (_ :: Proxy m) ->
      -- Relies on the SNat payload invariant: nat == natVal @n.
      gcastWith (unsafeCoerce (Refl @()) :: n :~: m) act

-- | Add two singletons; the result payload is the sum of the payloads.
(%+) :: SNat n -> SNat m -> SNat (n + m)
(%+) = coerce ((+) @Natural)

-- | Subtract two singletons using 'Natural' subtraction.
--
-- When @m > n@ the type @n - m@ has no natural value, and the underlying
-- 'Natural' subtraction fails at runtime (arithmetic underflow).
(%-) :: SNat n -> SNat m -> SNat (n - m)
(%-) = coerce ((-) @Natural)

-- | Multiply two singletons; the result payload is the product of the payloads.
(%*) :: SNat n -> SNat m -> SNat (n * m)
(%*) = coerce ((*) @Natural)

-- | Integer division of singletons, computed with 'div' on 'Natural'.
--
-- A zero divisor fails at runtime, just as 'div' does (and @Div n 0@ has
-- no value at the type level).
sDiv :: SNat n -> SNat m -> SNat (Div n m)
sDiv = coerce (div @Natural)

-- | Remainder of singleton division, computed with 'mod' on 'Natural'.
--
-- A zero divisor fails at runtime, just as 'mod' does (and @Mod n 0@ has
-- no value at the type level).
sMod :: SNat n -> SNat m -> SNat (Mod n m)
sMod = coerce (mod @Natural)

-- | Exponentiation of singletons: the payload of @n@ raised to the payload
-- of @m@, computed with 'Prelude.^' at 'Natural' for both base and exponent.
(%^) :: SNat n -> SNat m -> SNat (n ^ m)
(%^) = coerce ((^) @Natural @Natural)

-- | Base-2 logarithm (rounded down) of a singleton, computed with
-- 'naturalLog2' and converted back from 'Int' to 'Natural'.
--
-- NOTE(review): @Log2 0@ has no value at the type level; for a zero
-- payload this presumably fails at runtime (either inside 'naturalLog2'
-- or in the 'Int' to 'Natural' conversion) — confirm.
sLog2 :: SNat n -> SNat (Log2 n)
sLog2 = coerce (fromIntegral @Int @Natural . naturalLog2)

-- | The singleton for a statically known natural; the payload is read from
-- the 'KnownNat' instance with 'natVal''.
sNat :: forall n. KnownNat n => SNat n
sNat = SNat (natVal' (proxy# :: Proxy# n))

-- These fixities mirror the Prelude operators the singletons lift:
-- (+) and (-) are infixl 6; (*), `div` and `mod` are infixl 7; (^) is infixr 8.
infixl 6 %+, %-

infixl 7 %*, `sDiv`, `sMod`

infixr 8 %^

-- | Two singletons are equal exactly when their payloads are equal. In that
-- case the type equality is supplied by 'trustMe' (which produces
-- @x :~: y@ for any @x@ and @y@), so this relies on the 'SNat' payload
-- invariant.
instance TestEquality SNat where
  testEquality (SNat l) (SNat r)
    | l == r = Just trustMe
    | otherwise = Nothing

-- | Result of deciding equality of two indices, as produced by '%~'.
data Equality n m where
  -- | The indices coincide; also records that @n == n@ reduces to @'True@.
  Equal :: ((n == n) ~ 'True) => Equality n n
  -- | The indices differ: both @n === m@ and @n == m@ reduce to @'False@,
  -- and an 'Empty' instance for @n :~: m@ is available (presumably a proof
  -- that the equality is uninhabited — see "Proof.Propositional").
  NonEqual ::
    ((n === m) ~ 'False, (n == m) ~ 'False, Empty (n :~: m)) =>
    Equality n m

-- | Closed type family deciding syntactic type equality. It is @'True@
-- when both arguments are the same type, and @'False@ once GHC can show
-- they are apart. On unresolved type variables it stays stuck.
type family a === b where
  a === a = 'True
  _ === _ = 'False

infix 4 ===, %~

-- | Decide equality of two singletons at runtime by comparing payloads,
-- returning type-level evidence as an 'Equality'.
--
-- The evidence is built at fixed dummy indices (@()@ for 'Equal'; @0@ and
-- @1@ for 'NonEqual') and 'unsafeCoerce'd to the requested indices. This
-- is sound only under the 'SNat' payload invariant.
(%~) :: SNat l -> SNat r -> Equality l r
SNat l %~ SNat r
  | l == r = unsafeCoerce (Equal @())
  | otherwise = unsafeCoerce (NonEqual @0 @1)

-- | Type-level zero; an alias for the literal @0@.
type Zero = 0

-- | Type-level one; an alias for the literal @1@.
type One = 1

-- | The singleton for @0@.
sZero :: SNat 0
sZero = sNat

-- | The singleton for @1@.
sOne :: SNat 1
sOne = sNat

-- | Type-level successor: @Succ n@ is @n + 1@.
type Succ n = n + 1

-- | Short alias for the type-level successor.
type S n = Succ n

-- | Successor of a singleton: adds 'sOne'. 'sS' is a short alias.
sSucc, sS :: SNat n -> SNat (Succ n)
sS = (%+ sOne)
sSucc = sS

-- | Predecessor of a singleton: subtracts 'sOne' with '%-'.
--
-- Because '%-' uses 'Natural' subtraction, applying this to a zero payload
-- fails at runtime (mirroring @0 - 1@ having no natural value).
sPred :: SNat n -> SNat (Pred n)
sPred = (%- sOne)

-- | Type-level predecessor: @Pred n@ is @n - 1@ (which does not reduce to a natural when @n@ is @0@).
type Pred n = n - 1

-- | A view of a natural index as either zero or a successor, as produced
-- by 'viewNat'. Matching on a constructor refines the index @n@.
data ZeroOrSucc n where
  -- | The index is @0@.
  IsZero :: ZeroOrSucc 0
  -- | The index is @n + 1@; carries the singleton for @n@.
  IsSucc ::
    SNat n ->
    ZeroOrSucc (n + 1)

-- | Matches or builds the zero singleton. Matching goes through 'viewNat'
-- and brings @n ~ 0@ into scope; building yields 'sZero'.
pattern Zero :: forall n. () => n ~ 0 => SNat n
pattern Zero <- (viewNat -> IsZero)
  where
    Zero = sZero

-- | Matches or builds a successor singleton. Matching goes through
-- 'viewNat' and binds the predecessor (with @n ~ Succ n1@ in scope);
-- building applies 'sSucc'.
pattern Succ :: forall n. () => forall n1. n ~ Succ n1 => SNat n1 -> SNat n
pattern Succ n <- (viewNat -> IsSucc n)
  where
    Succ n = sSucc n

-- Zero and Succ together cover every SNat, since viewNat always returns
-- either IsZero or IsSucc.
{-# COMPLETE Zero, Succ #-}

-- | View a singleton as zero or as the successor of another singleton.
--
-- A zero payload is detected with 'testEquality' against @sNat \@0@. For
-- any other payload, @1 <=? n@ is asserted with 'trustMe'. That assertion
-- lets the type-checker plugin accept @n ~ (n - 1) + 1@ when building
-- 'IsSucc' from 'sPred'.
viewNat :: forall n. SNat n -> ZeroOrSucc n
viewNat n =
  case n `testEquality` sNat @0 of
    Just Refl -> IsZero
    Nothing ->
      -- The payload is non-zero here, so 1 <= n holds under the SNat invariant.
      gcastWith (trustMe :: (1 <=? n) :~: 'True) $
        IsSucc (sPred n)

-- | Reverse a type-level 'Ordering': @'LT@ and @'GT@ swap, @'EQ@ is unchanged.
type family FlipOrdering ord where
  FlipOrdering 'LT = 'GT
  FlipOrdering 'GT = 'LT
  FlipOrdering 'EQ = 'EQ

-- | Singleton counterpart of 'FlipOrdering': swaps 'SLT' and 'SGT' and
-- leaves 'SEQ' unchanged.
sFlipOrdering :: SOrdering ord -> SOrdering (FlipOrdering ord)
sFlipOrdering = \case
  SLT -> SGT
  SEQ -> SEQ
  SGT -> SLT

-- | Singleton for type-level 'Ordering'. Matching on a constructor reveals
-- the index @ord@. Produced by 'sCmpNat' and 'sCompare'.
data SOrdering (ord :: Ordering) where
  SLT :: SOrdering 'LT
  SEQ :: SOrdering 'EQ
  SGT :: SOrdering 'GT

deriving instance Show (SOrdering ord)

deriving instance Eq (SOrdering ord)

-- NOTE(review): modern GHC derives Typeable for every type automatically,
-- so this clause is presumably redundant (GHC may warn about it) — confirm
-- before removing.
deriving instance Typeable SOrdering

-- | Singleton for type-level 'Bool'. Matching on a constructor reveals the
-- index @b@. Produced by '%<=?'.
data SBool (b :: Bool) where
  SFalse :: SBool 'False
  STrue :: SBool 'True

deriving instance Show (SBool b)

deriving instance Eq (SBool b)

-- NOTE(review): modern GHC derives Typeable for every type automatically,
-- so this clause is presumably redundant (GHC may warn about it) — confirm
-- before removing.
deriving instance Typeable SBool

infix 4 %<=?

-- | Compare two singletons with @<=@ on their payloads, returning the
-- singleton for @n <=? m@.
--
-- The plain 'STrue'/'SFalse' is 'unsafeCoerce'd to the requested index,
-- which is sound only under the 'SNat' payload invariant.
(%<=?) :: SNat n -> SNat m -> SBool (n <=? m)
SNat n %<=? SNat m
  | n <= m = unsafeCoerce STrue
  | otherwise = unsafeCoerce SFalse

-- | Compare two singletons by payload, returning the singleton for
-- @CmpNat n m@. 'sCompare' is an alias.
--
-- Each result constructor is 'unsafeCoerce'd to the requested index, which
-- is sound only under the 'SNat' payload invariant.
sCmpNat, sCompare :: SNat n -> SNat m -> SOrdering (CmpNat n m)
sCompare = sCmpNat
sCmpNat (SNat n) (SNat m) =
  case compare n m of
    LT -> unsafeCoerce SLT
    EQ -> unsafeCoerce SEQ
    GT -> unsafeCoerce SGT