{-# LANGUAGE BangPatterns         #-}
{-# LANGUAGE DeriveDataTypeable   #-}
{-# LANGUAGE DeriveGeneric        #-}
module Numeric.Peano where
import           Data.List                   (unfoldr)
import           Control.DeepSeq             (NFData (rnf))
import           Data.Data                   (Data, Typeable)
import           GHC.Generics                (Generic)
import           Numeric.Natural
import           Data.Ix
import           Data.Function
import           Text.Read
-- | Peano-encoded natural numbers: a value is either zero ('Z') or the
-- successor ('S') of another 'Nat'.
--
-- The field of 'S' carries no strictness annotation, so a number is only
-- unfolded as far as a consumer inspects it.
data Nat
    = Z
    -- ^ Zero.
    | S Nat
    -- ^ The successor of the wrapped number (that is, one more than it).
    deriving (Eq,Generic,Data,Typeable)
-- | Right fold over a 'Nat': replaces every 'S' with the given function and
-- the terminating 'Z' with the given base value.
--
-- The recursion is not a tail call; the function is applied around the
-- recursive result, so a function that is lazy in its argument can produce
-- output before the whole number has been traversed.
foldrNat :: (a -> a) -> a -> Nat -> a
foldrNat step zero = unwind
  where
    unwind nat = case nat of
        Z      -> zero
        S rest -> step (unwind rest)
{-# INLINE foldrNat #-}
-- | Strict left fold over a 'Nat': applies the function once per 'S',
-- threading an accumulator that is forced to weak head normal form at every
-- step (via the bang pattern) before the number itself is inspected.
foldlNat' :: (a -> a) -> a -> Nat -> a
foldlNat' step = loop
  where
    loop !acc nat = case nat of
        Z      -> acc
        S rest -> loop (step acc) rest
{-# INLINE foldlNat' #-}
-- | Structural ordering on Peano naturals.
--
-- Every method peels constructors off both arguments in lock-step, so it
-- inspects only as many 'S' layers as the smaller argument has (plus one).
-- '<=' and 'min' answer without touching the second argument when the first
-- is 'Z'; '<' answers without touching the first when the second is 'Z'.
instance Ord Nat where
    compare a b = case (a, b) of
        (Z,   Z)   -> EQ
        (Z,   S _) -> LT
        (S _, Z)   -> GT
        (S x, S y) -> compare x y

    a <= b = case a of
        Z   -> True
        S x -> case b of
            Z   -> False
            S y -> x <= y

    -- a < S y  is the same question as  a <= y
    a < b = case b of
        Z   -> False
        S y -> a <= y

    a >= b = b <= a
    a > b = b < a

    -- 'min' and 'max' emit an 'S' per shared layer, so the result is
    -- produced incrementally.
    min a b = case a of
        Z   -> Z
        S x -> case b of
            Z   -> Z
            S y -> S (min x y)

    max a b = case a of
        Z   -> b
        S x -> case b of
            Z   -> a
            S y -> S (max x y)
-- | Arithmetic on Peano naturals.
--
-- * Addition and multiplication recurse on their /first/ argument.
-- * Subtraction saturates: when the right operand is at least as large as
--   the left one the result is 'Z' rather than an error.
-- * 'fromInteger' calls 'error' on negative input.
instance Num Nat where
    Z   + m = m
    S n + m = S (n + m)

    -- n copies of m added together, starting from zero
    n * m = foldrNat (\acc -> m + acc) Z n

    abs n = n

    signum n = case n of
        Z   -> Z
        S _ -> S Z

    fromInteger i = case compare i 0 of
        LT -> error "cannot convert negative integers to Peano numbers"
        _  -> build i
      where
        -- wraps one 'S' per unit; the result is built lazily
        build k
            | k == 0    = Z
            | otherwise = S (build (k - 1))

    a - b = case a of
        Z   -> Z
        S x -> case b of
            Z   -> a
            S y -> x - y
-- | 'minBound' is 'Z'; 'maxBound' is the infinite tower @S (S (S ...))@,
-- built by tying a recursive knot.
instance Bounded Nat where
    minBound = Z
    maxBound = let infinity = S infinity in infinity
-- | Shows a 'Nat' exactly as the 'Integer' obtained from 'toInteger' would
-- be shown, at the same precedence.
instance Show Nat where
    showsPrec prec nat = showsPrec prec (toInteger nat)
-- | Reads a 'Nat' by parsing a 'Natural' and converting it with
-- 'fromIntegral'.
instance Read Nat where
    readPrec = fromNatural <$> readPrec
      where
        fromNatural :: Natural -> Nat
        fromNatural = fromIntegral
-- | Full evaluation walks every 'S' constructor down to the final 'Z'.
instance NFData Nat where
    rnf nat = case nat of
        Z      -> ()
        S rest -> rnf rest
-- | Converts to 'Rational' by way of 'toInteger'.
instance Real Nat where
    toRational nat = toRational (toInteger nat)
-- | Enumeration of Peano naturals.
--
-- * 'pred' calls 'error' on 'Z'; 'toEnum' calls 'error' on a negative 'Int'.
-- * 'enumFrom' is always infinite.
-- * 'enumFromThen' is infinite when the second element is not smaller than
--   the first. When it is smaller (a descending sequence) the enumeration
--   stops at the last element that is still @>= 'Z'@, as the Haskell Report
--   prescribes for types that are both 'Bounded' and 'Enum'. For example
--   @[5,3..]@ is @[5,3,1]@.
-- * The bounded enumerations track the remaining distance to the limit as a
--   'Nat' and stop when it reaches 'Z'; they rely on @(-)@ saturating at
--   'Z'.
instance Enum Nat where
    succ = S
    pred (S n) = n
    pred Z = error "pred called on zero nat"
    fromEnum = foldlNat' succ 0
    toEnum m
      | m < 0 = error "cannot convert negative number to Peano"
      | otherwise = go m
      where
        go 0 = Z
        go n = S (go (n - 1))
    enumFrom = iterate S
    -- The unfold state is (next element, number of elements still to emit).
    enumFromTo n m = unfoldr f (n, S m - n)
      where
        f (_,Z) = Nothing
        f (e,S l) = Just (e, (S e, l))
    -- The first element is emitted before the two arguments are compared,
    -- so the head of the list is available without inspecting either one.
    enumFromThen n m = n : rest n m
      where
        -- Cancel the common prefix of n and m; whatever is left over is the
        -- step, and which side it is left on gives the direction.
        rest (S nn) (S mm) = rest nn mm
        -- m >= n: ascending (or constant when the step is Z), never ends.
        rest Z step        = iterate (step +) m
        -- m < n: descending. The state is (element, element + 1); because
        -- (-) saturates, the second component hits Z exactly when the next
        -- element would fall below Z, which ends the list instead of
        -- repeating Z forever.
        rest step@(S _) Z  = unfoldr down (m, S m)
          where
            down (_,Z)       = Nothing
            down (e,l@(S _)) = Just (e, (e - step, l - step))
    enumFromThenTo n m t = unfoldr f (n, jm)
      where
        -- Cancel the common prefix of n and m to find step and direction.
        -- The result is (distance budget, how to advance, step size).
        ts (S nn) (S mm) = ts nn mm
        ts Z mm = (S t - n, (+) mm, mm)
        ts nn Z = (S n - t, subtract nn, nn)
        (jm,tf,tt) = ts n m
        -- Each emitted element uses up one step's worth of the budget.
        td = subtract tt
        f (_,Z) = Nothing
        f (e,l@(S _)) = Just (e, (tf e, td l))
-- | Division by repeated subtraction.
--
-- Division by 'Z': 'rem' (and therefore 'mod') calls 'error'; 'quotRem'
-- (and 'divMod') returns 'maxBound' as the quotient paired with a remainder
-- that calls 'error' when forced. Since every value is non-negative, 'div',
-- 'mod' and 'divMod' are the same as 'quot', 'rem' and 'quotRem'.
instance Integral Nat where
    toInteger nat = foldlNat' (+ 1) 0 nat

    quotRem dividend divisor = case divisor of
        Z -> (maxBound, error "divide by zero")
        _ -> rounds Z dividend
      where
        -- One round tries to subtract the whole divisor from 'start'.
        rounds count start = strip start divisor
          where
            strip left todo = case todo of
                -- divisor fully subtracted: count it and go again
                Z       -> rounds (S count) left
                S todo' -> case left of
                    S left' -> strip left' todo'
                    -- ran out first: 'start' is the remainder
                    Z       -> (count, start)

    -- Emits one 'S' per completed subtraction, so the quotient is produced
    -- lazily.
    quot dividend divisor = chunk dividend
      where
        chunk = peel divisor
        peel need have = case need of
            Z       -> S (chunk have)
            S need' -> case have of
                S have' -> peel need' have'
                Z       -> Z

    rem dividend divisor = case divisor of
        Z -> error "divide by zero"
        _ -> rounds dividend
      where
        rounds start = strip start divisor
          where
            strip left todo = case todo of
                Z       -> rounds left
                S todo' -> case left of
                    S left' -> strip left' todo'
                    Z       -> start

    div a b = quot a b
    mod a b = rem a b
    divMod a b = quotRem a b
-- | Indexing with 'Nat' bounds.
--
-- 'inRange' and 'index' first strip the lower bound off the upper bound and
-- the candidate in lock-step, then compare what is left. 'index' calls
-- 'error' when the candidate lies outside the bounds.
instance Ix Nat where
    range (lo, hi) = enumFromTo lo hi

    inRange (lo, hi) ix = within lo hi ix
      where
        within l h i = case l of
            -- lower bound exhausted: only the upper bound remains to check
            Z    -> i <= h
            S l' -> case i of
                Z    -> False
                S i' -> case h of
                    Z    -> False
                    S h' -> within l' h' i'

    index (lo, hi) ix = shift lo hi ix
      where
        -- Subtract the lower bound from both the upper bound and the index.
        shift l h i = case l of
            Z    -> count 0 h i
            S l' -> case i of
                Z    -> error "out of range"
                S i' -> case h of
                    Z    -> error "out of range"
                    S h' -> shift l' h' i'
        -- Count the remaining index into a strict accumulator, checking
        -- that it does not exceed the remaining upper bound.
        count !acc h i = case (h, i) of
            (Z,    S _)  -> error "out of range"
            (S h', S i') -> count (acc + 1) h' i'
            (_,    Z)    -> acc