{-# LANGUAGE
    TypeApplications
  , ScopedTypeVariables
  , LambdaCase
#-}

module Atrophy.LongMultiplication where

import Control.Monad.ST.Strict (ST)
import Data.Bits
import Data.Foldable (for_)
import Data.STRef.Strict (modifySTRef, modifySTRef', newSTRef, readSTRef)
import Data.Word

import Data.Primitive.Contiguous (Mutable, MutableSliced, PrimArray)
import qualified Data.Primitive.Contiguous as Contiguous
import Data.WideWord.Word128

{-# INLINE multiply256By128UpperBits #-}
-- | Upper 128 bits of the 384-bit product of a 256-bit value and a 128-bit
-- value.
--
-- The 256-bit operand is passed as two 'Word128' halves, high half first
-- (@aHi@, then @aLo@). The result is bits 256..383 of
-- @(aHi * 2^128 + aLo) * b@.
multiply256By128UpperBits :: Word128 -> Word128 -> Word128 -> Word128
multiply256By128UpperBits aHi aLo b =
  Word128
    { word128Hi64 = Contiguous.index limbs 5
    , word128Lo64 = Contiguous.index limbs 4
    }
  where
    -- The 256-bit operand split into little-endian 64-bit limbs.
    aLimbs :: PrimArray Word64
    aLimbs =
      Contiguous.quadrupleton
        (word128Lo64 aLo)
        (word128Hi64 aLo)
        (word128Lo64 aHi)
        (word128Hi64 aHi)

    -- The 128-bit operand split into little-endian 64-bit limbs.
    bLimbs :: PrimArray Word64
    bLimbs = Contiguous.doubleton (word128Lo64 b) (word128Hi64 b)

    -- Schoolbook product: for the limb of b at position k, add aLimbs * that
    -- limb into the accumulator starting at limb k. Six limbs hold the full
    -- 4-limb by 2-limb product.
    limbs :: PrimArray Word64
    limbs = Contiguous.create $ do
      acc <- Contiguous.replicateMut 6 0
      Contiguous.itraverse_
        (\offset digit -> do
            accLen <- Contiguous.sizeMut acc
            multiply256By64Helper
              (Contiguous.sliceMut acc offset (accLen - offset))
              aLimbs
              digit)
        bLimbs
      pure acc

{-# INLINE multiply256By64Helper #-}
-- | Add @a * b@ into @prod@ in place (@prod += a * b@).
--
-- Both @prod@ and @a@ are little-endian arrays of 64-bit limbs: the carry
-- out of index @i@ is added into index @i + 1@. @prod@ is /not/ cleared
-- first, so it may already hold a partial product from an earlier call.
-- This call adds and carries on top of whatever is there.
--
-- * When @b == 0@ this returns immediately without touching @prod@.
--
-- * @prod@ must be at least as long as @a@. Otherwise this calls 'error'
--   rather than reading and writing past the end of the slice.
--
-- * If the sum does not fit in @prod@ (a non-zero carry remains after the
--   last limb), this calls 'error'. By then @prod@ has already been
--   partially updated.
multiply256By64Helper :: forall s. MutableSliced PrimArray s Word64 -> PrimArray Word64 -> Word64 -> ST s ()
multiply256By64Helper _ _ 0 = pure ()
multiply256By64Helper prod a b = do
  productSize <- Contiguous.sizeMut prod
  let aSize = Contiguous.size a
  if productSize < aSize
    then error "multiply256By64Helper: product buffer is shorter than the multiplicand"
    else do
      -- Running carry. A Word128 always suffices: the carry and p are each at
      -- most 2^64-1 and aDigit * b is at most (2^64-1)^2, so their sum is at
      -- most 2^128-1.
      carry <- newSTRef 0
      let
        -- Limbs of prod that line up with a.
        productLo :: MutableSliced PrimArray s Word64
        productLo = Contiguous.sliceMut prod 0 aSize

        -- Limbs of prod above a, which only receive carries.
        productHiSize :: Int
        productHiSize = productSize - aSize
        productHi :: MutableSliced PrimArray s Word64
        productHi = Contiguous.sliceMut prod aSize productHiSize

        -- Store the low 64 bits of the carry at index i of dst, and keep
        -- only the high 64 bits as the carry into the next limb.
        flushLimb :: MutableSliced PrimArray s Word64 -> Int -> ST s ()
        flushLimb dst i = do
          Contiguous.write dst i . word128Lo64 =<< readSTRef carry
          modifySTRef' carry (`unsafeShiftR` 64)

      -- Multiply each digit of a by b and add it, plus the carry, onto the
      -- limb already stored in prod. prod is not zeroed because this may be
      -- called several times, accumulating a partial product.
      Contiguous.itraverse_
        (\i aDigit -> do
            p <- Contiguous.read productLo i
            modifySTRef' carry $ \c ->
              c + Word128 0 p + Word128 0 aDigit * Word128 0 b
            flushLimb productLo i)
        a

      -- Propagate the remaining carry through the upper limbs.
      for_ [0 .. productHiSize - 1] $ \i -> do
        p <- Contiguous.read productHi i
        modifySTRef' carry (+ Word128 0 p)
        flushLimb productHi i

      readSTRef carry >>= \case
        0 -> pure ()
        _ -> error "carry overflow during multiplication!"

{-# INLINE longMultiply #-}
-- | Compute @prod += a * b@ in place.
--
-- @a@ and @prod@ are arrays of 64-bit limbs; @b@ is a single limb. All the
-- work is done by 'multiply256By64Helper' on a slice covering the whole of
-- @prod@. That helper calls 'error' if the final carry does not fit in
-- @prod@.
longMultiply :: forall s. PrimArray Word64 -> Word64 -> Mutable PrimArray s Word64 -> ST s ()
longMultiply a b prod =
  Contiguous.toSliceMut prod >>= \wholeBuffer ->
    multiply256By64Helper wholeBuffer a b