{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
module Haskus.Number.Posit
   ( Posit (..)
   , PositKind (..)
   , PositK (..)
   , positKind
   , isZero
   , isInfinity
   , isPositive
   , isNegative
   , positAbs
   , PositEncoding (..)
   , PositFields (..)
   , positEncoding
   , positFields
   , positToRational
   , positFromRational
   , positApproxFactor
   , positDecimalError
   , positDecimalAccuracy
   , positBinaryError
   , positBinaryAccuracy
   , floatBinaryAccuracy
   )
where
import Haskus.Number.Int
import Haskus.Binary.Bits
import Haskus.Utils.Types
import Haskus.Utils.Tuple
import Haskus.Utils.Flow
import Data.Ratio
import qualified GHC.Real as Ratio
-- | Posit number whose raw encoding is stored in an @nbits@-bit integer
-- ('IntN').
--
-- @es@ is a type-level parameter giving the maximum number of exponent bits
-- of the encoding.
newtype Posit (nbits :: Nat) (es :: Nat) = Posit (IntN nbits)
-- | Render a posit as text: @"0"@ for the zero encoding, @"Infinity"@ for the
-- infinity encoding, and the shown 'Rational' computed by 'positToRational'
-- for every other value.
instance
   ( Bits (IntN n)
   , FiniteBits (IntN n)
   , Ord (IntN n)
   , Num (IntN n)
   , KnownNat n
   , KnownNat es
   , Integral (IntN n)
   ) => Show (Posit n es)
   where
   show p
      | isZero p     = "0"
      | isInfinity p = "Infinity"
      | otherwise    = show (positToRational p)
-- | The three kinds of posit encodings distinguished by this module.
--
-- Used (promoted with @DataKinds@) as the type-level index of 'PositK'.
data PositKind
   = ZeroK      -- ^ The zero encoding
   | InfinityK  -- ^ The infinity encoding
   | NormalK    -- ^ Any other encoding
   deriving (Show,Eq)
-- | A posit tagged at the type level with its 'PositKind'.
--
-- Only the 'Value' constructor carries an actual 'Posit'.
data PositK k nbits es where
   -- | The zero posit
   Zero     :: PositK 'ZeroK nbits es
   -- | The infinity posit
   Infinity :: PositK 'InfinityK nbits es
   -- | A posit that is neither zero nor infinity
   Value    :: Posit nbits es -> PositK 'NormalK nbits es
-- | Existential wrapper around 'PositK' hiding the kind index.
--
-- This is what 'positKind' returns; pattern match on the wrapped constructor
-- to recover the kind.
data SomePosit n es where
   SomePosit :: PositK k n es -> SomePosit n es
-- | A posit statically known to be of the 'NormalK' kind (built with 'Value').
type PositValue n es = PositK 'NormalK n es
-- | Classify a posit as 'Zero', 'Infinity' or a normal 'Value'.
--
-- The zero test is performed first, then the infinity test; anything else is
-- wrapped with 'Value'.
positKind :: forall n es.
   ( Bits (IntN n)
   , KnownNat n
   , Eq (IntN n)
   ) => Posit n es -> SomePosit n es
positKind p = if
   | isZero p     -> SomePosit Zero
   | isInfinity p -> SomePosit Infinity
   | otherwise    -> SomePosit (Value p)
-- | Check whether a posit is zero, i.e. whether its raw encoding has no bit
-- set.
isZero :: forall n es.
   ( Bits (IntN n)
   , Eq (IntN n)
   , KnownNat n
   ) => Posit n es -> Bool
{-# INLINABLE isZero #-}
isZero p = case p of
   Posit raw -> raw == zeroBits
-- | Check whether a posit is the infinity value, i.e. whether its raw
-- encoding has only bit @n - 1@ (the topmost one) set.
isInfinity :: forall n es.
   ( Bits (IntN n)
   , Eq (IntN n)
   , KnownNat n
   ) => Posit n es -> Bool
{-# INLINABLE isInfinity #-}
isInfinity p = case p of
   Posit raw -> raw == bit (natValue @n - 1)
-- | Check whether a normal posit value is positive: its raw encoding compares
-- strictly greater than the all-zero encoding.
isPositive :: forall n es.
   ( Bits (IntN n)
   , Ord (IntN n)
   , KnownNat n
   ) => PositValue n es -> Bool
{-# INLINABLE isPositive #-}
isPositive v = case v of
   Value (Posit raw) -> raw > zeroBits
-- | Check whether a normal posit value is negative: its raw encoding compares
-- strictly lower than the all-zero encoding.
isNegative :: forall n es.
   ( Bits (IntN n)
   , Ord (IntN n)
   , KnownNat n
   ) => PositValue n es -> Bool
{-# INLINABLE isNegative #-}
isNegative v = case v of
   Value (Posit raw) -> raw < zeroBits
-- | Absolute value of a normal posit, obtained by applying 'abs' to its raw
-- integer encoding.
positAbs :: forall n es.
   ( Num (IntN n)
   , KnownNat n
   ) => PositValue n es -> PositValue n es
positAbs (Value (Posit raw)) = Value . Posit . abs $ raw
-- | Decoded fields of a normal posit, as produced by 'positFields'.
--
-- The regime, exponent and fraction are decoded from the absolute value of
-- the posit; the sign is reported separately in 'positNegative'.
data PositFields = PositFields
   { positNegative         :: Bool  -- ^ Whether the posit is negative
   , positRegimeBitCount   :: Word  -- ^ Number of bits used by the regime field
   , positExponentBitCount :: Word  -- ^ Number of exponent bits actually present (at most @es@)
   , positFractionBitCount :: Word  -- ^ Number of fraction bits
   , positRegime           :: Int   -- ^ Regime value (signed)
   , positExponent         :: Word  -- ^ Exponent bits as stored in the encoding
   , positFraction         :: Word  -- ^ Fraction bits as stored in the encoding
   }
   deriving (Show)
-- | Result of decoding a posit with 'positEncoding': either one of the two
-- special values or the decoded fields of a normal value.
data PositEncoding
   = PositInfinity              -- ^ The posit is the infinity value
   | PositZero                  -- ^ The posit is zero
   | PositEncoding PositFields  -- ^ Decoded fields of a normal posit
   deriving (Show)
-- | Decode a posit into a 'PositEncoding'.
--
-- Zero and infinity map to their dedicated constructors; every other posit is
-- decoded with 'positFields'.
positEncoding :: forall n es.
   ( Bits (IntN n)
   , Ord (IntN n)
   , Num (IntN n)
   , KnownNat n
   , KnownNat es
   , Integral (IntN n)
   ) => Posit n es -> PositEncoding
positEncoding p
   | isZero p     = PositZero
   | isInfinity p = PositInfinity
   | otherwise    = PositEncoding (positFields (Value p))
-- | Decode the fields (sign, regime, exponent, fraction) of a normal posit.
--
-- The sign is read from the posit itself; the other fields are decoded from
-- its absolute value (see 'positAbs').
positFields :: forall n es.
   ( Bits (IntN n)
   , Ord (IntN n)
   , Num (IntN n)
   , KnownNat n
   , KnownNat es
   , Integral (IntN n)
   ) => PositValue n es -> PositFields
positFields p = PositFields
      { positNegative         = isNegative p
      , positRegimeBitCount   = regimeWidth
      , positExponentBitCount = exponentWidth
      , positFractionBitCount = fractionWidth
      , positRegime           = regimeValue
      , positExponent         = exponentBits
      , positFraction         = fractionBits
      }
   where
      -- raw encoding of the absolute value: its sign bit is clear
      Value (Posit magnitude) = positAbs p

      -- the bit just below the sign bit tells which kind of run the regime is
      regimeStartsWithOne = magnitude `testBit` (natValue @n - 2)

      -- length of the run of identical bits forming the regime
      runLength
         -- run of 1s: complement them (and drop the complemented sign bit) so
         -- that they can be counted as leading zeros; the sign bit itself is
         -- counted too, hence the "- 1"
         | regimeStartsWithOne = countLeadingZeros (complement magnitude `clearBit` (natValue @n - 1)) - 1
         -- run of 0s: counted directly, minus the sign bit
         | otherwise           = countLeadingZeros magnitude - 1

      -- signed regime value derived from the run length
      regimeValue
         | regimeStartsWithOne = fromIntegral runLength - 1
         | otherwise           = negate (fromIntegral runLength)

      -- regime bit count: the run plus one more bit, capped to what is
      -- available after the sign bit
      regimeWidth = min (natValue @n - 1) (runLength + 1)

      -- exponent bit count: at most @es@ bits, fewer when the regime leaves
      -- less room
      exponentWidth = min (natValue @n - regimeWidth - 1) (natValue @es)

      -- fraction bit count: whatever remains
      fractionWidth = natValue @n - exponentWidth - regimeWidth - 1

      exponentBits = fromIntegral (maskDyn exponentWidth (magnitude `shiftR` fractionWidth))
      fractionBits = fromIntegral (maskDyn fractionWidth magnitude)
-- | Convert a posit into a 'Rational'.
--
-- * zero is converted into @0@
-- * infinity is converted into 'Ratio.infinity'
-- * any other posit is converted into
--   @sign * useed^regime * 2^exponent * (1 + fraction / 2^fractionBitCount)@
--   where @useed = 2^(2^es)@
--
-- The fields are decoded by 'positFields' from the absolute value of the
-- posit, so the sign has to be applied explicitly to the result.
--
-- When the regime is long enough to cut off some of the @es@ exponent bits,
-- the missing low-order exponent bits are considered to be zeros: the stored
-- exponent bits are scaled accordingly before being used.
positToRational :: forall n es.
   ( KnownNat n
   , KnownNat es
   , Eq (IntN n)
   , Bits (IntN n)
   , Integral (IntN n)
   ) => Posit n es -> Rational
positToRational p
   | isZero p     = 0 Ratio.:% 1
   | isInfinity p = Ratio.infinity
   | otherwise    = applySign ((fromIntegral useed ^^ r) * (2 ^^ e) * (1 + (f % fd)))
      where
         fields    = positFields (Value p)
         -- fields describe the magnitude only: restore the sign
         applySign = if positNegative fields then negate else id
         r         = positRegime fields
         -- number of exponent bits cut off by the regime (never negative as
         -- the exponent bit count is at most es)
         missing   = natValue @es - positExponentBitCount fields
         -- pad the stored exponent bits with zeros on the right
         e         = positExponent fields * 2 ^ missing
         f         = fromIntegral (positFraction fields)
         fd        = 1 `shiftL` positFractionBitCount fields
         useed     = 1 `shiftL` (1 `shiftL` natValue @es) :: Integer -- 2^(2^es)
-- | Convert a 'Rational' into the nearest posit.
--
-- * @0@ is converted into the all-zero encoding
-- * 'Ratio.infinity' is converted into the encoding with only the top bit set
-- * any other value is encoded field after field (regime, exponent, fraction)
--   from its absolute value, with one additional trailing bit used for
--   rounding (ties go to the even encoding); the result is finally negated
--   if the input is negative.
--
-- The type parameter @p@ allows the target posit type to be selected with a
-- type application.
positFromRational :: forall p n es.
   ( Posit n es ~ p
   , Num (IntN n)
   , Bits (IntN n)
   , KnownNat es
   , KnownNat n
   ) => Rational -> Posit n es
positFromRational x = if
      | x == 0              -> Posit 0
      | x == Ratio.infinity -> Posit (bit (natValue @n - 1))
      | otherwise           -> computeRegime
                              |> uncurry3 computeExponent
                              |> uncurry3 computeFraction
                              |> uncurry  computeRounding
                              |> computeSign
                              |> Posit
   where
      useed = fromIntegral (1 `shiftL` (1 `shiftL` es) :: Integer) -- 2^(2^es)
      nbits = natValue @n
      es    = natValue @es

      -- Compute regime bits. Returns (y,p,i) where:
      --    y: remaining value to encode
      --    p: current posit encoding
      --    i: current posit size
      computeRegime
         | absx >= 1 = regime111 absx 1 2
         | otherwise = regime000 absx 1
         where
            absx = abs x

            -- regime is a sequence of 1s: shift in a 1 and scale down by
            -- useed until the value fits, then shift in the terminating 0
            regime111 y p i
               | y >= useed && i < nbits = regime111 (y / useed) ((p `uncheckedShiftL` 1) .|. 1) (i+1)
               | otherwise               = (y, p `uncheckedShiftL` 1, i+1)

            -- regime is a sequence of 0s: scale up by useed until the value
            -- reaches 1; the encoding only receives the terminating 1
            regime000 y i
               | y < 1 && i <= nbits = regime000 (y*useed) (i+1)
               | i >= nbits          = (y,2,nbits+1)
               | otherwise           = (y,1,i+1)

      -- Compute exponent bits, from the most significant one
      --    y: remaining value to encode
      --    p: current posit encoding
      --    i: current posit size
      computeExponent
            | es == 0   = (,,)
            | otherwise = go (1 `shiftL` (es - 1))
         where
            -- e is the weight of the exponent bit being computed
            go e y p i
               | i > nbits || e == 0 = (y,p,i)
               | y >= pow2e          = go (e `uncheckedShiftR` 1) (y / pow2e) ((p `uncheckedShiftL` 1) .|. 1) (i+1)
               | otherwise           = go (e `uncheckedShiftR` 1) y            (p `uncheckedShiftL` 1)        (i+1)
               where
                  pow2e = fromIntegral (1 `shiftL` e :: Integer)

      -- Compute fraction bits
      --    y: remaining value to encode
      --    p: current posit encoding
      --    i: current posit size
      computeFraction y' = go (y'-1) -- remove the hidden bit
         where
            go y p i
               | i > nbits = (y,p)
               | y <= 0    = (y, p `shiftL` (nbits+1-i)) -- don't forget the rounding bit
               -- the next fraction bit is floor(2*y): it is set as soon as
               -- the doubled remainder reaches 1 (">=", not ">"), otherwise
               -- a remaining fraction of exactly 1/2 would never be encoded
               -- as a single set bit
               | y2 >= 1   = go (y2-1) (p `shiftL` 1 + 1) (i+1)
               | otherwise = go y2     (p `shiftL` 1)     (i+1)
               where
                  y2 = 2*y

      -- Compute rounding: drop the trailing rounding bit of the encoding
      --    y: remaining value to encode
      --    p: current posit encoding (with the rounding bit)
      computeRounding y p =
         let p' = p `uncheckedShiftR` 1
         in if | not (p `testBit` 0) -> p'                                     -- closer to lower value
               | y == 1 || y == 0    -> p' + (if p' `testBit` 0 then 1 else 0) -- tie goes to nearest even
               | otherwise           -> p' + 1                                 -- closer to upper value

      -- Negate the encoding of negative inputs
      computeSign p
         | x < 0     = negate p
         | otherwise = p
-- | Factor of approximation of a 'Rational' by a posit type: the value
-- obtained by a round-trip through the posit type @p@ divided by the original
-- value, as a 'Double'.
--
-- The posit type is selected with a type application on @p@.
positApproxFactor :: forall p n es.
   ( Posit n es ~ p
   , Num (IntN n)
   , Bits (IntN n)
   , Integral (IntN n)
   , KnownNat es
   , KnownNat n
   ) => Rational -> Double
positApproxFactor r = fromRational (roundTripped / r)
   where
      encoded      = positFromRational r :: p
      roundTripped = positToRational encoded
-- | Decimal error of the approximation of a 'Rational' by the posit type @p@:
-- absolute value of the base-10 logarithm of 'positApproxFactor'.
positDecimalError :: forall p n es.
   ( Posit n es ~ p
   , Num (IntN n)
   , Bits (IntN n)
   , Integral (IntN n)
   , KnownNat es
   , KnownNat n
   ) => Rational -> Double
positDecimalError r = abs (logBase 10 factor)
   where
      factor = positApproxFactor @p r
-- | Decimal accuracy of the approximation of a 'Rational' by the posit type
-- @p@: negated base-10 logarithm of 'positDecimalError'.
positDecimalAccuracy :: forall p n es.
   ( Posit n es ~ p
   , Num (IntN n)
   , Bits (IntN n)
   , Integral (IntN n)
   , KnownNat es
   , KnownNat n
   ) => Rational -> Double
positDecimalAccuracy r = negate (logBase 10 decimalError)
   where
      decimalError = positDecimalError @p r
-- | Binary error of the approximation of a 'Rational' by the posit type @p@:
-- absolute value of the base-2 logarithm of 'positApproxFactor'.
positBinaryError :: forall p n es.
   ( Posit n es ~ p
   , Num (IntN n)
   , Bits (IntN n)
   , Integral (IntN n)
   , KnownNat es
   , KnownNat n
   ) => Rational -> Double
positBinaryError r = abs (logBase 2 factor)
   where
      factor = positApproxFactor @p r
-- | Binary accuracy of the approximation of a 'Rational' by the posit type
-- @p@: negated base-2 logarithm of 'positBinaryError'.
positBinaryAccuracy :: forall p n es.
   ( Posit n es ~ p
   , Num (IntN n)
   , Bits (IntN n)
   , Integral (IntN n)
   , KnownNat es
   , KnownNat n
   ) => Rational -> Double
positBinaryAccuracy r = negate (logBase 2 binaryError)
   where
      binaryError = positBinaryError @p r
-- | Binary accuracy of the approximation of a 'Rational' by the fractional
-- type @f@ (selected with a type application).
--
-- The value is converted into @f@ and back; the ratio between the result and
-- the original value is the approximation factor. The accuracy is the negated
-- base-2 logarithm of the absolute base-2 logarithm of this factor.
floatBinaryAccuracy :: forall f.
   ( Fractional f
   , Real f
   ) => Rational -> Double
floatBinaryAccuracy r = negate (logBase 2 binaryError)
   where
      roundTripped = toRational (fromRational r :: f)
      factor       = fromRational (roundTripped / r) :: Double
      binaryError  = abs (logBase 2 factor)