{-# LANGUAGE
TypeApplications
, ScopedTypeVariables
, LambdaCase
#-}
module Atrophy.LongMultiplication where
import Control.Monad (unless, when)
import Control.Monad.ST.Strict (ST)
import Data.Bits
import Data.Foldable (for_)
import Data.STRef.Strict (newSTRef, modifySTRef, readSTRef)
import Data.Word

import Data.Primitive.Contiguous (PrimArray, MutableSliced, Mutable)
import qualified Data.Primitive.Contiguous as Contiguous
import Data.WideWord.Word128
{-# INLINE multiply256By128UpperBits #-}
-- | Upper 128 bits of the 384-bit product of a 256-bit value and a 128-bit
-- value.
--
-- The 256-bit multiplicand is passed as two halves, @aHi@ (most significant)
-- and @aLo@ (least significant); @b@ is the 128-bit multiplier. Both operands
-- are split into little-endian 64-bit digits and multiplied schoolbook-style
-- into a six-digit (384-bit) scratch buffer. Digits 4 and 5 of that buffer,
-- i.e. bits 256..383 of the full product, form the result.
multiply256By128UpperBits :: Word128 -> Word128 -> Word128 -> Word128
multiply256By128UpperBits aHi aLo b =
  Word128
    { word128Hi64 = Contiguous.index prod 5
    , word128Lo64 = Contiguous.index prod 4
    }
  where
    -- Little-endian 64-bit digits of the 256-bit multiplicand.
    aChunks :: PrimArray Word64
    aChunks =
      Contiguous.quadrupleton
        (word128Lo64 aLo)
        (word128Hi64 aLo)
        (word128Lo64 aHi)
        (word128Hi64 aHi)

    -- Little-endian 64-bit digits of the 128-bit multiplier.
    bChunks :: PrimArray Word64
    bChunks = Contiguous.doubleton (word128Lo64 b) (word128Hi64 b)

    -- 4 digits of a + 2 digits of b: room for the full product.
    productDigits :: Int
    productDigits = 6

    -- The full 384-bit product as six little-endian digits.
    prod :: PrimArray Word64
    prod = Contiguous.create $ do
      buf <- Contiguous.replicateMut productDigits 0
      flip Contiguous.itraverse_ bChunks $ \bIndex bDigit ->
        -- Multiplying by digit bIndex of b shifts that partial product left
        -- by bIndex digits, so accumulate it into the tail starting there.
        multiply256By64Helper
          (Contiguous.sliceMut buf bIndex (productDigits - bIndex))
          aChunks
          bDigit
      pure buf
{-# INLINE multiply256By64Helper #-}
-- | In-place multiply-accumulate by a single 64-bit digit: @prod += a * b@,
-- where @prod@ and @a@ are little-endian arrays of 64-bit digits.
--
-- The first @size a@ digits of @prod@ receive the digit-wise products plus
-- whatever they already held. The remaining digits of @prod@ only absorb the
-- carry. Multiplying by @0@ leaves @prod@ untouched.
--
-- Calls 'error' if @prod@ has fewer digits than @a@. It also calls 'error'
-- if the final carry does not fit in @prod@, i.e. the accumulated result
-- overflows the buffer.
multiply256By64Helper :: forall s. MutableSliced PrimArray s Word64 -> PrimArray Word64 -> Word64 -> ST s ()
multiply256By64Helper _ _ 0 = pure ()
multiply256By64Helper prod a b = do
  productSize <- Contiguous.sizeMut prod
  let aSize :: Int
      aSize = Contiguous.size a
  -- Digit reads/writes below are not bounds-checked, so a too-short buffer
  -- must be rejected up front instead of being silently overrun.
  when (productSize < aSize) $
    error "multiply256By64Helper: product buffer is shorter than the multiplicand"
  let
    widen :: Word64 -> Word128
    widen = Word128 0

    -- prod[i] += a[i] * b + carry for every digit of a. Never overflows
    -- 128 bits: (2^64-1) + (2^64-1)^2 + (2^64-1) = 2^128 - 1.
    mulDigits :: Int -> Word128 -> ST s Word128
    mulDigits i carry
      | i >= aSize = pure carry
      | otherwise = do
          p <- Contiguous.read prod i
          let acc = carry + widen p + widen (Contiguous.index a i) * widen b
          Contiguous.write prod i (word128Lo64 acc)
          mulDigits (i + 1) (acc `unsafeShiftR` 64)

    -- Ripple the remaining carry through the upper digits of prod,
    -- stopping as soon as it is exhausted.
    propagate :: Int -> Word128 -> ST s Word128
    propagate i carry
      | carry == 0 || i >= productSize = pure carry
      | otherwise = do
          p <- Contiguous.read prod i
          let acc = carry + widen p
          Contiguous.write prod i (word128Lo64 acc)
          propagate (i + 1) (acc `unsafeShiftR` 64)
  finalCarry <- propagate aSize =<< mulDigits 0 0
  unless (finalCarry == 0) $
    error "carry overflow during multiplication!"
{-# INLINE longMultiply #-}
-- | In-place multiply-accumulate of a little-endian multi-digit number by a
-- single 64-bit digit: @prod += a * b@.
--
-- @prod@ is viewed as a slice spanning the whole mutable array and handed to
-- 'multiply256By64Helper'. Its existing digits are therefore added to, not
-- overwritten, so zero it first to obtain a plain product. Fails with 'error'
-- under the same conditions as 'multiply256By64Helper'.
longMultiply :: forall s. PrimArray Word64 -> Word64 -> Mutable PrimArray s Word64 -> ST s ()
longMultiply a b prod = do
  prodSlice <- Contiguous.toSliceMut prod
  multiply256By64Helper prodSlice a b