{-# LANGUAGE MagicHash #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE BangPatterns #-}

module Data.FastMult.Internal (FastMult(FastMult), FastMultSeq, simplify) where

#include "MachDeps.h"

import Prelude hiding (Integer)

import GHC.Integer.GMP.Internals (BigNat, Integer(S#, Jp#, Jn#), sizeofBigNat#, timesBigNat, bigNatToWord, wordToBigNat, wordToBigNat2)
import Data.Bits (FiniteBits, countLeadingZeros, finiteBitSize, xor, complement)
import GHC.Base (timesWord2#, Word(W#), int2Word#, Int(I#), eqWord#, (>=#), negateInt#)
import Data.Ord (comparing)
import Data.Strict.List ( List((:!)) )
import qualified Data.Strict.List as Strict
import GHC.TypeLits (Nat, KnownNat, natVal)
import GHC.Conc.Sync (par)
import Data.Proxy (Proxy)
import Data.Word (Word)
import Data.Ratio ((%))
import Data.Foldable (foldl')

-- The sign is stored as a 'Word' rather than a 'Bool' because it unpacks better.
-- I'm not sure if this is an optimisation.
newtype Sign = Sign Word deriving Show

-- | Non-negative sign: all bits clear.
pattern Pos :: Sign
pattern Pos = Sign 0

-- | Negative sign: all bits set.
pattern Neg :: Sign
pattern Neg = Sign (-1)

-- | Sign of a product: the bitwise xor of the two underlying words.
multSigns :: Sign -> Sign -> Sign
multSigns (Sign a) (Sign b) = Sign (a `xor` b)

-- | Flip a sign by complementing every bit of the underlying word.
negateSign :: Sign -> Sign
negateSign (Sign bits) = Sign (complement bits)

-- | A 'BigNat' paired with its scale (the first, unpacked 'Word' field).
data BigNatWithScale (n :: Nat) where
  BigNatWithScale :: KnownNat n => {-# UNPACK #-} !Word -> BigNat -> BigNatWithScale n

-- Internal debug only:
instance Show (BigNatWithScale n) where
  show (BigNatWithScale scale bigNat) =
    concat
      [ "(BigNatWithScale scale = "
      , show scale
      , ", num = "
      , show (Jp# bigNat)
      , ")"
      ]

-- | Project out the 'BigNat', discarding the scale.
getBigNat :: BigNatWithScale (n :: Nat) -> BigNat
getBigNat (BigNatWithScale _ nat) = nat

{-|
'FastMult' is a Numeric type that can be used in any place a 'Num a' is required.

It represents a standard integer using three components, which multiplied together represent the stored number:

1. The number's sign

2. An unsigned machine word.

3.
A (possibly empty) list of 'BigNat's, which are the internal type for 'Integer's which are too large to fit in a machine word.

Each 'BigNat' in the list has a scale. Its scale is the log base 2, rounded down, of one less than the number of machine words needed to store the number. Note that we never store 'BigNat's with a length of only one machine word in this list; we instead convert them to an ordinary unsigned machine word and multiply them by item 2 in the list above. Only if that result overflows do we place them in this 'BigNat' list.

Here are a few examples of "MachineWords: Scale":

* 2: 0

* 3: 1

* 4: 1

* 5: 2

* 6..8: 2

* 9..16: 3

* 17..32: 4

etc.

Note this "scale" has the very nice property that multiplying 'BigNat's of scale @x@ always results in a 'BigNat' of scale @x+1@.

The list of 'BigNat's only ever contains one 'BigNat' of each "scale". As the size of 'BigNat's increases exponentially with scale, this list should always be relatively small.

The 'BigNat' list is always sorted as well, smallest to largest.

When we multiply two 'FastMult's, we merge the 'BigNat' lists. This is basically a simple merge of sorted lists, but with one significant change. Note that we said that the 'BigNat' list cannot contain two 'BigNat's of the same scale. So if we find that a 'BigNat' in the left hand list of the multiplication is the same scale as a 'BigNat' in the right hand list, we multiply these two 'BigNat's to create a 'BigNat' one "scale" larger. We then continue the merge, including this new 'BigNat'.

As a result, we only ever multiply numbers of the same "scale", that is, no more than double the length of one another.

Why do we do this? Well, an ordinary product, say @product [1..1000000]@, towards the end of the list involves multiplications of a very large number by a machine word. These take @O(n)@ time. So the whole product takes @O(n^2)@ time.
If we instead did the following:

@
product x y = product x mid * product mid y
mid = (x + y) `div` 2
(suitable base case here)
@

We find that this runs a lot faster. The reason is that with this approach we're minimising products involving very large numbers, and importantly, multiplying two @n@ length numbers doesn't take @O(n^2)@ but more like @O(n*log(n))@ time. For this reason it's better to do a few multiplications of large numbers by large numbers, instead of lots of multiplications of large numbers by small numbers.

But to do this I've had to redefine product. What if you don't want to change the algorithm, but just want to use one that's already been written, perhaps inefficiently? Well, this is where 'FastMult' is useful. Instead of making the algorithm smarter, 'FastMult' just makes numbers smarter. The numbers themselves reorder the multiplications so you don't have to.

As well as having the advantage of speeding up existing algorithms, 'FastMult' dynamically behaves differently based on what numbers it's actually multiplying, and always maintains the invariant that multiplications will not be performed between numbers more than twice the size of each other.

At this point I haven't mentioned the meaning of the 'FastMult' type parameter @n@. 'FastMult' can also add parallelism to your multiplication algorithms. However, sparking new GHC threads has a cost, so we only want to do it for large multiplications. Multiplications of @scale > n@ will spark a new thread, so @n = 0@ will spark new threads for any multiplication involving at least 3 machine words. This is probably too small; you can experiment with different numbers. Note that @n@ represents the scale, not the size, so for example setting @n=4@ will only spark threads for multiplications involving at least 33 machine words.

How well parallelism works (or if it works at all) hasn't been tested yet, however.

We include an ordinary machine word in the type as an optimisation for single machine word numbers.
This is because multiplying 'BigNat's involves calling GMP using a C call,
which is a large overhead for small multiplications.

To use 'FastMult', all you have to do is import its type, not its implementation.
If you're not interested in parallelism, just import 'FastMultSeq'.

For example, just compare in GHCi:

@
product [1 .. 100000]
@

and:

@
product [1 :: FastMultSeq .. 100000]
@

and you should find the latter completes much faster.

Converting to and from 'Integer's can be done with the 'toInteger' and 'fromInteger'
class methods from 'Integral' and 'Num' respectively.
-}
data FastMult (n :: Nat) where
  FastMult :: KnownNat n => {-# UNPACK #-} !Sign -> {-# UNPACK #-} !Word -> !(Strict.List (BigNatWithScale n)) -> FastMult n

{-|
A type synonym for a fully sequential 'FastMult'.
The parameter is supposed to be 'WORD_MAX', but I couldn't find that defined,
anyway what's important is that anything of scale smaller than @0xFFFFFFFF@ will be sequential,
which is everything.
-}
type FastMultSeq = FastMult 0xFFFFFFFF

-- | Outcome of trying to multiply two scaled 'BigNat's: the scales either
-- differ ('ScaleLT' / 'ScaleGT', no multiplication performed) or are equal
-- ('ScaleEQ', carrying the product at the next scale up).
data BigNatMultResult (n :: Nat) where
  ScaleLT :: BigNatMultResult n
  ScaleEQ :: (KnownNat n) => BigNatWithScale n -> BigNatMultResult n
  ScaleGT :: BigNatMultResult n

-- | A strict list holding exactly one element.
singletonStrictList :: a -> Strict.List a
singletonStrictList x = x :! Strict.Nil

-- | Equality is decided on the fully multiplied-out 'Integer' value.
instance KnownNat n => Eq (FastMult n) where
  x == y = toInteger x == toInteger y

-- | Ordering is decided on the fully multiplied-out 'Integer' value.
instance KnownNat n => Ord (FastMult n) where
  x `compare` y = toInteger x `compare` toInteger y

instance KnownNat n => Enum (FastMult n) where
  toEnum = fromIntegral
  fromEnum = fromIntegral

instance KnownNat n => Num (FastMult n) where
  -- Small integers go straight into the machine-word component; big ones
  -- become a single-element 'BigNat' list (with the word component set to 1).
  fromInteger = \case
      (S# prim_i) -> case (prim_i >=# 0#) of
        1# -> FastMult Pos (W# (int2Word# prim_i)) Strict.Nil
        _ -> FastMult Neg (W# (int2Word# (negateInt# prim_i))) Strict.Nil
      (Jp# x) -> fromBigNat Pos x
      (Jn# x) -> fromBigNat Neg x
    where
      fromBigNat :: Sign -> BigNat -> FastMult n
      fromBigNat sign x = case (W# (int2Word# (sizeofBigNat# x)) - 1) of
        -- A one-word BigNat is stored in the machine-word component instead.
        0 -> FastMult sign (W# (bigNatToWord x)) Strict.Nil
        size -> FastMult sign 1 (singletonStrictList (BigNatWithScale (logBase2Int size) x))

      -- Index of the highest set bit, i.e. floor (log2 x) for non-zero x.
      logBase2Int :: Word -> Word
      logBase2Int x = WORD_SIZE_IN_BITS - 1 - (fromIntegral (countLeadingZeros x))

  -- Multiply the signs, multiply the two machine words (spilling into a
  -- scale-0 'BigNat' carry when the product needs two words), and merge the
  -- two sorted 'BigNat' lists, multiplying together entries of equal scale.
  (FastMult sign1 w1 l1) * (FastMult sign2 w2 l2) =
    let
      -- The explicit forall binds a fresh n here (shadowing the instance's n);
      -- its KnownNat evidence comes from matching the GADT constructor.
      multBigNatWithScale :: forall n. BigNatWithScale n -> BigNatWithScale n -> BigNatMultResult n
      multBigNatWithScale (BigNatWithScale scale1 n1) (BigNatWithScale scale2 n2) =
          case (scale1 `compare` scale2) of
            EQ -> result `seqOrPar` (ScaleEQ (BigNatWithScale (scale1 + 1) result))
              where
                result = n1 `timesBigNat` n2
                -- Scales above the type-level threshold are sparked with 'par'.
                seqOrPar = if scale1 <= maxSeq then seq else par
            LT -> ScaleLT
            GT -> ScaleGT
        where
          maxSeq = fromIntegral (natVal (undefined :: Proxy n))

      signr = multSigns sign1 sign2

      -- Full two-word product of the machine-word components: (high, low).
      (# wu_prim, wl_prim #) =
        let
          !(W# w1_prim) = w1
          !(W# w2_prim) = w2
        in
          timesWord2# w1_prim w2_prim

      merge :: Strict.List (BigNatWithScale n) -> Strict.List (BigNatWithScale n) -> Strict.List (BigNatWithScale n)
      merge xl Strict.Nil = xl
      merge Strict.Nil yl = yl
      merge xl@(x:!xs) yl@(y:!ys) = case multBigNatWithScale x y of
        ScaleEQ result -> mergeWithCarry result xs ys
        ScaleLT -> x :! merge xs yl
        ScaleGT -> y :! merge xl ys

      mergeWithCarry :: BigNatWithScale n -> Strict.List (BigNatWithScale n) -> Strict.List (BigNatWithScale n) -> Strict.List (BigNatWithScale n)
      mergeWithCarry carry xl Strict.Nil = mergeOneCarry carry xl
      mergeWithCarry carry Strict.Nil yl = mergeOneCarry carry yl
      mergeWithCarry carry xl@(x:!xs) yl@(y:!ys) = case multBigNatWithScale x y of
          ScaleEQ result -> carry :! mergeWithCarry result xs ys
          ScaleLT -> contCarry x xs yl
          ScaleGT -> contCarry y ys xl
        where
          contCarry x xs yl = case multBigNatWithScale carry x of
            ScaleEQ result -> mergeWithCarry result xs yl
            ScaleLT -> carry :! x :! merge xs yl
            ScaleGT -> error $
              "Carry should never be larger than first element. This should never happen. Report as bug.\n" ++
              "Details:\n" ++
              "carry =\n" ++
              show carry ++ "\n" ++
              "xl =\n" ++
              show xl ++ "\n" ++
              "yl =\n" ++
              show yl ++ "\n"

      mergeOneCarry :: BigNatWithScale n -> Strict.List (BigNatWithScale n) -> Strict.List (BigNatWithScale n)
      mergeOneCarry carry Strict.Nil = singletonStrictList carry
      mergeOneCarry carry xl@(x:!xs) = case multBigNatWithScale carry x of
        ScaleLT -> carry :! xl
        ScaleEQ result -> mergeOneCarry result xs
        ScaleGT -> error $
          "Carry should never be larger than first element. This should never happen. Report as bug.\n" ++
          "Details:\n" ++
          "carry =\n" ++
          show carry ++ "\n" ++
          "xl =\n" ++
          show xl ++ "\n"
    in
      -- 'eqWord#' yields 1# when its arguments are equal. A zero high word
      -- means the product fits in one machine word; otherwise the two-word
      -- product becomes a scale-0 carry that is merged into the 'BigNat' list.
      case eqWord# wu_prim (int2Word# 0#) of
        1# -> FastMult signr (W# wl_prim) (merge l1 l2)
        _ -> FastMult signr 1 (mergeWithCarry (BigNatWithScale 0 (wordToBigNat2 wu_prim wl_prim)) l1 l2)

  (+) = binaryViaInteger (+)
  (-) = binaryViaInteger (-)

  abs (FastMult _ word l) = FastMult Pos word l

  -- The stored value is the product of the word and the 'BigNat's, so a zero
  -- word component means the value is zero, whose signum must be zero.
  signum (FastMult sign word _)
    | word == 0 = FastMult Pos 0 Strict.Nil
    | otherwise = FastMult sign 1 Strict.Nil

  negate (FastMult sign word l) = FastMult (negateSign sign) word l

-- | Lift a binary 'Integer' operation by converting both arguments with
-- 'toInteger' and the result with 'fromInteger'.
binaryViaInteger :: (Integral a, Integral b, Num c) => (Integer -> Integer -> Integer) -> a -> b -> c
binaryViaInteger f x y = fromInteger (toInteger x `f` toInteger y)

-- | Lift a unary 'Integer' operation through 'toInteger' / 'fromInteger'.
unaryViaInteger :: (Integral a, Num b) => (Integer -> Integer) -> a -> b
unaryViaInteger f = fromInteger . f . toInteger

instance KnownNat n => Real (FastMult n) where
  toRational x = (toInteger x) % 1

instance KnownNat n => Integral (FastMult n) where
  -- Multiply all components out into one 'BigNat' magnitude, then wrap it in
  -- the 'Integer' constructor that integer-gmp's representation invariant
  -- requires: 'S#' whenever the value fits in an 'Int', 'Jp#' / 'Jn#' only
  -- beyond that range. Building 'Jp#' / 'Jn#' for small values would yield
  -- 'Integer's that compare unequal to ordinary ones of the same value.
  toInteger (FastMult sign (W# word_prim) l) =
      case sizeofBigNat# magnitude of
        1# -> fromSingleLimb (W# (bigNatToWord magnitude))
        _ -> wide
    where
      magnitude :: BigNat
      magnitude = foldl' (\acc y -> acc `timesBigNat` getBigNat y) (wordToBigNat word_prim) l

      -- Representation for magnitudes outside the 'Int' range.
      wide :: Integer
      wide = case sign of
        Pos -> Jp# magnitude
        Neg -> Jn# magnitude

      -- A one-word magnitude fits in 'S#' when it is at most maxBound :: Int
      -- (positive) or at most maxBound + 1, i.e. |minBound| (negative). In the
      -- latter boundary case 'fromIntegral' wraps to minBound, and negating
      -- minBound leaves it unchanged, which is exactly the value required.
      fromSingleLimb :: Word -> Integer
      fromSingleLimb w = case (sign, fromIntegral w :: Int) of
        (Pos, I# i) | w <= maxIntAsWord -> S# i
        (Neg, I# i) | w <= maxIntAsWord + 1 -> S# (negateInt# i)
        _ -> wide

      maxIntAsWord :: Word
      maxIntAsWord = fromIntegral (maxBound :: Int)

  x `quotRem` y =
    let (q, r) = toInteger x `quotRem` toInteger y
    in (fromInteger q, fromInteger r)

instance KnownNat n => Show (FastMult n) where
  show = show . toInteger

instance KnownNat n => Read (FastMult n) where
  readsPrec p s = map (\(x,y) -> (fromInteger x,y)) (readsPrec p s)

{-|
'simplify' returns a 'FastMult' the same as its argument but "simplified".

To explain this, consider the following for @x :: FastMult@:

@
f x = (show x, x + 1)
@

It will multiply out @x@ twice, once for the addition, and once for 'show'.

Note that the list of 'BigNat's in @x@ is generally a small number,
as only one 'BigNat' is stored for each scale, and the sizes of scales increase exponentially,
but there may be some multiplications required nevertheless.

A better way to write this is as follows:

@
f x = let y = simplify x in (show y, y + 1)
@

This will ensure that @x@ is multiplied out only once.

Unfortunately using 'simplify' stops your algorithms from being generic,
so it might be better to define simplify as 'id' with a rewrite rule.
I'll think about this.
-}
simplify :: KnownNat n => FastMult n -> FastMult n
simplify = fromInteger . toInteger