{-# LANGUAGE
    CPP
  , TypeApplications
  , DataKinds
  , FlexibleContexts
  , DuplicateRecordFields
  , TypeFamilies
  , BangPatterns
  , NumericUnderscores
  , ScopedTypeVariables
  , DerivingStrategies
  , GeneralizedNewtypeDeriving
#-}

module Atrophy.Internal where

import Data.WideWord.Word128
import Data.Bits
import Atrophy.Internal.LongDivision
import GHC.Records
import Data.Word

-- | Wrapper marking a value (a divisor, in this module) that the caller
-- promises is non-zero.
--
-- NOTE(review): nothing enforces the invariant. The constructor is exported
-- without a check, and the derived 'Num' instance happily produces zero
-- (e.g. @NonZero 1 - NonZero 1@).
newtype NonZero a = NonZero a
  deriving newtype Num
  deriving newtype Show

-- | Bounds of a 'NonZero': the upper bound is inherited from @a@, and the
-- lower bound is 1.
--
-- NOTE(review): 1 is the smallest non-zero value only for unsigned @a@; for
-- a signed @a@ this 'minBound' is not the true minimum — confirm intent.
instance (Bounded a, Num a) => Bounded (NonZero a) where
  maxBound = NonZero maxBound
  minBound = NonZero 1

{-# INLINE isPowerOf2 #-}
-- | Is the argument an exact power of two (1, 2, 4, 8, ...)?
--
-- Zero is /not/ a power of two. The bare bit trick @x .&. (x - 1) == 0@
-- also holds for zero, so zero is excluded explicitly. Otherwise a zero
-- divisor would be routed to the shift-based fast path and later shifted by
-- @countTrailingZeros 0@, i.e. the full bit width, which 'unsafeShiftR'
-- leaves undefined.
--
-- Intended for unsigned words; for a signed type the most negative value
-- also satisfies the bit trick.
isPowerOf2 :: (Bits a, Num a) => a -> Bool
isPowerOf2 x = x /= 0 && (x .&. (x - 1)) == 0

{-# INLINE new64 #-}
-- | Precompute the strength-reduced form of a 64-bit divisor.
--
-- * For a power-of-two divisor the multiplier is 0, which 'divRem64'
--   handles with a shift and a mask.
-- * For any other divisor the multiplier is @divide128MaxBy64 divisor + 1@.
--   This is presumably @ceiling (2^128 / divisor)@; it depends on
--   'divide128MaxBy64', so confirm.
new64 :: NonZero Word64 -> StrengthReducedW64
new64 (NonZero divi)
  | isPowerOf2 divi = StrengthReducedW64 0 divi
  | otherwise       = StrengthReducedW64 (divide128MaxBy64 divi + 1) divi

{-# INLINE divRem64 #-}
-- | Quotient and remainder of @dividend@ by a strength-reduced divisor whose
-- multiplier is a 'Word128' (e.g. 'StrengthReducedW64').
--
-- * A multiplier of 0 marks a power-of-two divisor. The quotient is then a
--   right shift by the divisor's trailing-zero count, and the remainder is a
--   mask with @divisor - 1@.
-- * Otherwise the multiplier is split into 64-bit halves. The quotient is
--   @floor (dividend * multiplier / 2^128)@, assembled from two 128-bit
--   partial products. The remainder is recovered as
--   @dividend - quotient * divisor@.
--
-- NOTE(review): the partial products only stay within 128 bits when the
-- dividend fits in 64 bits (unsigned). Confirm that callers never
-- instantiate @a@ with a wider or signed type.
divRem64 ::
  ( HasField "divisor" strRed a
  , HasField "multiplier" strRed Word128
  , Integral a
  , FiniteBits a
  ) => a -> strRed -> (a, a)
divRem64 dividend divis
  | mult == 0 =
      ( dividend `unsafeShiftR` countTrailingZeros d
      , dividend .&. (d - 1)
      )
  | otherwise =
      let q = fromIntegral (upper128 (hiPart + loPart))
      in (q, dividend - q * d)
  where
    mult   = getField @"multiplier" divis
    d      = getField @"divisor" divis
    -- dividend widened to 128 bits for the multiplication
    wide   = fromIntegral @_ @Word128 dividend
    -- dividend * (high 64 bits of the multiplier)
    hiPart = wide * upper128 mult
    -- carry contribution: high 64 bits of dividend * (low 64 bits of the multiplier)
    loPart = upper128 (wide * lower128 mult)

{-# INLINE divRem #-}
{-# SPECIALIZE divRem :: Word32 -> StrengthReducedW32 -> (Word32, Word32) #-}
-- | Width-generic counterpart of 'divRem64'. It returns the quotient and
-- remainder of @dividend@ by a strength-reduced divisor whose multiplier has
-- type @b@.
--
-- * A multiplier of 0 marks a power-of-two divisor, handled with a shift and
--   a mask.
-- * Otherwise the multiplier is split with 'upperHalf' and 'lowerHalf'. The
--   quotient is the upper half of the double-width product
--   @dividend * multiplier@, built from two @b@-sized partial products.
--
-- NOTE(review): this avoids overflow only when the dividend fits in
-- @'Half' b@ bits. Confirm this holds for every instantiation used.
divRem :: forall strRed a b.
  ( HasField "divisor" strRed a
  , HasField "multiplier" strRed b
  , Integral a
  , FiniteBits a
  , Integral b
  , FiniteBits (Half b)
  , Bits b
  ) => a -> strRed -> (a, a)
divRem dividend divis
  | mult == 0 =
      ( dividend `unsafeShiftR` countTrailingZeros d
      , dividend .&. (d - 1)
      )
  | otherwise =
      let q = fromIntegral (upperHalf (hiPart + loPart))
      in (q, dividend - q * d)
  where
    mult   = getField @"multiplier" divis
    d      = getField @"divisor" divis
    -- dividend widened to the multiplier's type for the multiplication
    wide   = fromIntegral @_ @b dividend
    -- dividend * (upper half of the multiplier)
    hiPart = wide * upperHalf mult
    -- carry contribution: upper half of dividend * (lower half of the multiplier)
    loPart = upperHalf (wide * lowerHalf mult)

{-# INLINE new #-}
{-# SPECIALIZE new :: (Word64 -> Word32 -> StrengthReducedW32) -> NonZero Word32 -> StrengthReducedW32 #-}
-- | Generic constructor for a strength-reduced divisor. It computes the
-- multiplier and passes it, together with the divisor, to @con@ (for example
-- the 'StrengthReducedW32' data constructor).
--
-- * Power-of-two divisors get multiplier 0.
-- * Other divisors get @maxBound \`div\` divisor + 1@ in the double-width
--   'Multiplier' type. For an unsigned multiplier type this equals
--   @ceiling (2^k / divisor)@, where @k@ is that type's bit width.
new
  :: (Bits t, Integral t, Bounded (Multiplier t), Integral (Multiplier t))
  => (Multiplier t -> t -> a)
  -> NonZero t
  -> a
new con (NonZero divi)
  | isPowerOf2 divi = con 0 divi
  | otherwise       = con (maxBound `div` fromIntegral divi + 1) divi

{-# INLINE div64 #-}
{-# SPECIALIZE div64 :: Word64 -> StrengthReducedW64 -> Word64 #-}
-- | Quotient only, via 'divRem64'.
div64 :: (HasField "divisor" r b, HasField "multiplier" r Word128,
 Integral b, FiniteBits b) =>
  b -> r -> b
div64 a rhs = case divRem64 a rhs of
  (quotient, _) -> quotient

{-# INLINE rem64 #-}
{-# SPECIALIZE rem64 :: Word64 -> StrengthReducedW64 -> Word64 #-}
-- | Remainder only, via 'divRem64'.
rem64 :: (HasField "divisor" r b, HasField "multiplier" r Word128,
 Integral b, FiniteBits b) =>
  b -> r -> b
rem64 a rhs = case divRem64 a rhs of
  (_, remainder) -> remainder

{-# INLINE div' #-}
{-# SPECIALIZE div' :: Word32 -> StrengthReducedW32 -> Word32 #-}
-- | Quotient only, via the width-generic 'divRem'.
div' ::
  ( HasField "divisor" strRed b
  , HasField "multiplier" strRed w
  , Integral b, FiniteBits b,  Integral w, FiniteBits (Half w), Bits w) => b -> strRed -> b
div' a rhs = case divRem a rhs of
  (quotient, _) -> quotient

{-# INLINE rem' #-}
{-# SPECIALIZE rem' :: Word32 -> StrengthReducedW32 -> Word32 #-}
-- | Remainder only, via the width-generic 'divRem'.
rem' ::
  ( HasField "divisor" strRed b
  , HasField "multiplier" strRed w
  , Integral b, FiniteBits b,  Integral w, FiniteBits (Half w), Bits w
  ) => b -> strRed -> b
rem' a rhs = case divRem a rhs of
  (_, remainder) -> remainder

{-# INLINE lower128 #-}
-- | Keep the low 64-bit word of a 'Word128' and zero the high word.
lower128 :: Word128 -> Word128
lower128 w = case w of
  Word128 _ lo -> Word128 0 lo

{-# INLINE upper128 #-}
-- | Move the high 64-bit word of a 'Word128' into the low word, zeroing the
-- high word (equivalent to a right shift by 64).
upper128 :: Word128 -> Word128
upper128 w = case w of
  Word128 hi _ -> Word128 0 hi

{-# INLINE lowerHalf #-}
-- | Clear the upper half of @w@, keeping only its low @'Half' w@ bits.
--
-- It works by shifting the low half to the top and back down. This assumes
-- @w@ is exactly twice as wide as @'Half' w@, which is true for every entry
-- of the 'Half' table.
lowerHalf :: forall w. ( FiniteBits (Half w), Bits w) =>w -> w
lowerHalf w = clearTop w
  where
    halfWidth = finiteBitSize (zeroBits :: Half w)
    clearTop  = (`unsafeShiftR` halfWidth) . (`unsafeShiftL` halfWidth)

{-# INLINE upperHalf #-}
-- | Shift the upper half of @w@ down into its low half. The half width is
-- the bit size of @'Half' w@.
upperHalf :: forall w. ( Bits w, FiniteBits (Half w)) =>w -> w
upperHalf w = w `unsafeShiftR` halfWidth
  where
    halfWidth = finiteBitSize (zeroBits :: Half w)

-- | The double-width unsigned type that holds the strength-reduction
-- multiplier for a divisor of type @a@.
type family Multiplier a where
  Multiplier Word8  = Word16
  Multiplier Word16 = Word32
  Multiplier Word32 = Word64
  Multiplier Word64 = Word128

-- | The unsigned type half as wide as @a@. Its bit size is the split point
-- used by 'upperHalf' and 'lowerHalf'.
type family Half a where
  Half Word16  = Word8
  Half Word32  = Word16
  Half Word64  = Word32
  Half Word128 = Word64

-- | Strength-reduced 64-bit divisor: a 'Word128' multiplier plus the
-- original divisor. A multiplier of 0 marks a power-of-two divisor.
data StrengthReducedW64 = StrengthReducedW64
  { multiplier :: {-# UNPACK #-} !Word128
  , divisor    :: {-# UNPACK #-} !Word64
  }

-- | Strength-reduced 32-bit divisor (64-bit multiplier).
data StrengthReducedW32 = StrengthReducedW32
  { multiplier :: {-# UNPACK #-} !Word64
  , divisor    :: {-# UNPACK #-} !Word32
  }

-- | Strength-reduced 16-bit divisor (32-bit multiplier).
data StrengthReducedW16 = StrengthReducedW16
  { multiplier :: {-# UNPACK #-} !Word32
  , divisor    :: {-# UNPACK #-} !Word16
  }

-- | Strength-reduced 8-bit divisor (16-bit multiplier).
--
-- NOTE(review): the data constructor is spelled @StrengthReducedW7@. The
-- W16/W32/W64 types above all use a constructor named after the type, so
-- this looks like a typo. It is left unchanged because renaming it would
-- break any caller using the current name; fix it deliberately.
data StrengthReducedW8 = StrengthReducedW7
  { multiplier :: {-# UNPACK #-} !Word16
  , divisor    :: {-# UNPACK #-} !Word8
  }

-- | Strength-reduced 128-bit divisor. It carries two 'Word128' multiplier
-- words and the divisor.
--
-- NOTE(review): nothing in this module builds or consumes this type.
-- @multiplierHi@/@multiplierLo@ are presumably the high and low halves of a
-- 256-bit multiplier; confirm.
data StrengthReducedW128 = StrengthReducedW128
  { multiplierHi :: {-# UNPACK #-} !Word128
  , multiplierLo :: {-# UNPACK #-} !Word128
  , divisor      :: {-# UNPACK #-} !Word128
  }