{-# language DataKinds #-}
{-# language ExplicitForAll #-}
{-# language KindSignatures #-}
{-# language MagicHash #-}
{-# language RankNTypes #-}
{-# language ScopedTypeVariables #-}
{-# language TypeOperators #-}
{-# language UnboxedTuples #-}

module Arithmetic.Nat
  ( -- * Addition
    plus
  , plus#
    -- * Subtraction
  , monus
    -- * Division
  , divide
  , divideRoundingUp
    -- * Multiplication
  , times
    -- * Successor
  , succ
    -- * Compare
  , testEqual
  , testLessThan
  , testLessThanEqual
  , testZero
  , (=?)
  , (<?)
  , (<=?)
    -- * Constants
  , zero
  , one
  , two
  , three
  , constant
    -- * Unboxed Constants
  , zero#
    -- * Convert
  , demote
  , unlift
  , lift
  , with
  ) where

import Prelude hiding (succ)

import Arithmetic.Types
import Arithmetic.Unsafe ((:=:)(Eq), type (<=)(Lte))
import Arithmetic.Unsafe (Nat(Nat),Nat#(Nat#),type (<)(Lt))
import GHC.Exts (Proxy#,proxy#,(+#))
import GHC.TypeNats (type (+),type (-),Div,KnownNat,natVal')
import GHC.Int (Int(I#))

import qualified GHC.TypeNats as GHC

-- | Infix synonym of 'testLessThan'. Returns a proof that the first
-- argument is strictly less than the second, or 'Nothing'.
(<?) :: Nat a -> Nat b -> Maybe (a < b)
{-# inline (<?) #-}
(<?) = testLessThan

-- | Infix synonym of 'testLessThanEqual'. Returns a proof that the
-- first argument is less than or equal to the second, or 'Nothing'.
(<=?) :: Nat a -> Nat b -> Maybe (a <= b)
{-# inline (<=?) #-}
(<=?) = testLessThanEqual

-- | Infix synonym of 'testEqual'. Returns a proof that the two
-- arguments are equal, or 'Nothing'.
(=?) :: Nat a -> Nat b -> Maybe (a :=: b)
{-# inline (=?) #-}
(=?) = testEqual

-- | Is the first argument strictly less than the second
-- argument? The runtime 'Int' values are compared. When the
-- comparison holds, the proof 'Lt' is returned; otherwise 'Nothing'.
testLessThan :: Nat a -> Nat b -> Maybe (a < b)
{-# inline testLessThan #-}
testLessThan (Nat x) (Nat y)
  | x < y = Just Lt
  | otherwise = Nothing

-- | Is the first argument less-than-or-equal-to the second
-- argument? The runtime 'Int' values are compared. When the
-- comparison holds, the proof 'Lte' is returned; otherwise 'Nothing'.
testLessThanEqual :: Nat a -> Nat b -> Maybe (a <= b)
{-# inline testLessThanEqual #-}
testLessThanEqual (Nat x) (Nat y)
  | x <= y = Just Lte
  | otherwise = Nothing

-- | Are the two arguments equal to one another? The runtime 'Int'
-- values are compared. When they are equal, the proof 'Eq' is
-- returned; otherwise 'Nothing'.
testEqual :: Nat a -> Nat b -> Maybe (a :=: b)
{-# inline testEqual #-}
testEqual (Nat x) (Nat y)
  | x == y = Just Eq
  | otherwise = Nothing

-- | Is zero equal to this number or less than it?
--
-- Returns @'Left' 'Eq'@ when the runtime value is @0@ and
-- @'Right' 'Lt'@ for any other value. The 'Right' case relies on the
-- module-wide invariant that a 'Nat' is never negative.
testZero :: Nat a -> Either (0 :=: a) (0 < a)
{-# inline testZero #-}
testZero (Nat x) = case x of
  0 -> Left Eq
  _ -> Right Lt

-- | Add two numbers.
--
-- The runtime values are added as 'Int's; no overflow check is
-- performed.
plus :: Nat a -> Nat b -> Nat (a + b)
{-# inline plus #-}
plus (Nat x) (Nat y) = Nat (x + y)

-- | Variant of 'plus' for unboxed nats. Adds the underlying 'Int#'
-- values with '+#'; no overflow check is performed.
plus# :: Nat# a -> Nat# b -> Nat# (a + b)
{-# inline plus# #-}
plus# (Nat# x) (Nat# y) = Nat# (x +# y)

-- | Divide two numbers. Rounds down (towards zero).
--
-- Both operands are non-negative, so rounding down and rounding
-- towards zero coincide. A zero divisor makes 'div' throw a
-- divide-by-zero 'ArithException'.
divide :: Nat a -> Nat b -> Nat (Div a b)
{-# inline divide #-}
divide (Nat x) (Nat y) = Nat (div x y)

-- | Divide two numbers. Rounds up (away from zero).
--
-- Computed as @1 + div (x - 1) y@. A zero divisor makes 'div' throw
-- a divide-by-zero 'ArithException'.
divideRoundingUp :: Nat a -> Nat b -> Nat (Div (a - 1) b + 1)
{-# inline divideRoundingUp #-}
divideRoundingUp (Nat x) (Nat y) =
  -- Implementation note. We must use div so that when x=0,
  -- the result is (-1) and not 0. Then when we add 1, we get 0.
  -- (quot would round (-1) towards zero and give 1 instead.)
  Nat (1 + div (x - 1) y)

-- | Multiply two numbers.
--
-- The runtime values are multiplied as 'Int's; no overflow check is
-- performed. The type-level product uses the qualified @GHC.*@
-- because only @+@ and @-@ are imported unqualified.
times :: Nat a -> Nat b -> Nat (a GHC.* b)
{-# inline times #-}
times (Nat x) (Nat y) = Nat (x * y)

-- | The successor of a number, defined as @'plus' n 'one'@.
succ :: Nat a -> Nat (a + 1)
{-# inline succ #-}
succ n = plus n one

-- | Subtract the second argument from the first argument.
--
-- Returns 'Nothing' when the second argument is larger than the
-- first, since the difference would be negative. Otherwise it returns
-- the difference @c@ together with a proof that @c + b = a@.
monus :: Nat a -> Nat b -> Maybe (Difference a b)
{-# inline monus #-}
monus (Nat a) (Nat b) =
  let c :: Int
      c = a - b
   in if c >= 0
        then Just (Difference (Nat c) Eq)
        else Nothing

-- | The number zero.
zero :: Nat 0
{-# inline zero #-}
zero = Nat 0

-- | The number one.
one :: Nat 1
{-# inline one #-}
one = Nat 1

-- | The number two.
two :: Nat 2
{-# inline two #-}
two = Nat 2

-- | The number three.
three :: Nat 3
{-# inline three #-}
three = Nat 3

-- | Use GHC's built-in type-level arithmetic to create a witness
-- of a type-level number. This only reduces if the number is a
-- constant.
--
-- Calls 'error' if the type-level number is larger than
-- @'maxBound' :: 'Int'@. Without this check, 'fromIntegral' would
-- silently wrap the 'Natural' into a possibly negative 'Int'. That
-- would violate the non-negativity invariant the rest of this module
-- relies on. For a statically known @n@ the comparison should fold
-- away after inlining.
constant :: forall n. KnownNat n => Nat n
{-# inline constant #-}
constant
  | v > fromIntegral (maxBound :: Int) =
      error "Arithmetic.Nat.constant: type-level number does not fit in Int"
  | otherwise = Nat (fromIntegral v)
  where
    v = natVal' (proxy# :: Proxy# n)

-- | The number zero. Unboxed.
--
-- Takes a dummy @(# #)@ argument. Presumably this is because 'Nat#'
-- is unlifted and so cannot be bound directly at the top level.
zero# :: (# #) -> Nat# 0
{-# inline zero# #-}
zero# _ = Nat# 0#

-- | Extract the 'Int' from a 'Nat'. This is intended to be used
-- at a boundary where a safe interface meets the unsafe primitives
-- on top of which it is built.
demote :: Nat n -> Int
{-# inline demote #-}
demote (Nat n) = n

-- | Run a computation on a witness of a type-level number. The
-- argument 'Int' must be greater than or equal to zero. This is
-- not checked. Failure to uphold this invariant will lead to a
-- segfault.
with :: Int -> (forall n. Nat n -> a) -> a
{-# inline with #-}
with i f = f (Nat i)

-- | Convert a boxed 'Nat' to the unboxed 'Nat#' with the same
-- type-level index by unwrapping the underlying 'Int#'.
unlift :: Nat n -> Nat# n
{-# inline unlift #-}
unlift (Nat (I# i)) = Nat# i

-- | Convert an unboxed 'Nat#' to the boxed 'Nat' with the same
-- type-level index by boxing the underlying 'Int#'. Inverse of
-- 'unlift'.
lift :: Nat# n -> Nat n
{-# inline lift #-}
lift (Nat# i) = Nat (I# i)