{-# LANGUAGE CPP #-} {-# LANGUAGE BangPatterns #-} {-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE RankNTypes #-} #ifndef BITVEC_THREADSAFE module Data.Bit.F2Poly #else module Data.Bit.F2PolyTS #endif ( F2Poly , unF2Poly , toF2Poly ) where import Control.DeepSeq import Control.Exception import Control.Monad import Control.Monad.ST #ifndef BITVEC_THREADSAFE import Data.Bit.Immutable import Data.Bit.Internal import Data.Bit.Mutable #else import Data.Bit.ImmutableTS import Data.Bit.InternalTS import Data.Bit.MutableTS #endif import Data.Bit.Utils import Data.Bits import Data.Coerce import Data.Primitive.ByteArray import Data.Typeable import qualified Data.Vector.Unboxed as U import qualified Data.Vector.Unboxed.Mutable as MU import GHC.Generics #if UseIntegerGmp import qualified Data.Vector.Primitive as P import GHC.Exts import GHC.Integer.GMP.Internals import GHC.Integer.Logarithms import Unsafe.Coerce #endif -- | Binary polynomials of one variable, backed -- by an unboxed 'Data.Vector.Unboxed.Vector' 'Bit'. -- -- Polynomials are stored normalized, without leading zero coefficients. -- -- 'Ord' instance does not make much sense mathematically, -- it is defined only for the sake of 'Data.Set.Set', 'Data.Map.Map', etc. -- -- >>> :set -XBinaryLiterals -- >>> -- (1 + x) (1 + x + x^2) = 1 + x^3 (mod 2) -- >>> 0b11 * 0b111 :: F2Poly -- F2Poly {unF2Poly = [1,0,0,1]} newtype F2Poly = F2Poly { unF2Poly :: U.Vector Bit -- ^ Convert 'F2Poly' to a vector of coefficients -- (first element corresponds to a constant term). } deriving (Eq, Ord, Show, Typeable, Generic, NFData) -- | Make 'F2Poly' from a list of coefficients -- (first element corresponds to a constant term). toF2Poly :: U.Vector Bit -> F2Poly toF2Poly xs = F2Poly $ dropWhileEnd $ castFromWords $ cloneToWords xs -- | Valid 'F2Poly' has offset 0 and no trailing garbage. _isValid :: F2Poly -> Bool _isValid (F2Poly (BitVec o l arr)) = o == 0 && l == l' where l' = U.length $ dropWhileEnd $ BitVec 0 (sizeofByteArray arr `shiftL` 3) arr -- | Addition and multiplication are evaluated modulo 2. -- -- 'abs' = 'id' and 'signum' = 'const' 1. -- -- 'fromInteger' converts a binary polynomial, encoded as 'Integer', -- to 'F2Poly' encoding. instance Num F2Poly where (+) = coerce xorBits (-) = coerce xorBits negate = id abs = id signum = const (F2Poly (U.singleton (Bit True))) (*) = coerce ((dropWhileEnd .) . karatsuba) #if UseIntegerGmp fromInteger !n = case n of S# i# -> F2Poly $ BitVec 0 (wordSize - I# (word2Int# (clz# (int2Word# i#)))) $ fromBigNat $ wordToBigNat (int2Word# i#) Jp# bn# -> F2Poly $ BitVec 0 (I# (integerLog2# n) + 1) $ fromBigNat bn# Jn#{} -> error "F2Poly.fromInteger: argument must be non-negative" #else fromInteger = F2Poly . dropWhileEnd . integerToBits #endif {-# INLINE (+) #-} {-# INLINE (-) #-} {-# INLINE negate #-} {-# INLINE abs #-} {-# INLINE signum #-} {-# INLINE (*) #-} {-# INLINE fromInteger #-} instance Enum F2Poly where fromEnum = fromIntegral #if UseIntegerGmp toEnum !(I# i#) = F2Poly $ BitVec 0 (wordSize - I# (word2Int# (clz# (int2Word# i#)))) $ fromBigNat $ wordToBigNat (int2Word# i#) #else toEnum = fromIntegral #endif instance Real F2Poly where toRational = fromIntegral -- | 'toInteger' converts a binary polynomial, encoded as 'F2Poly', -- to 'Integer' encoding. instance Integral F2Poly where toInteger = bitsToInteger . unF2Poly quotRem (F2Poly xs) (F2Poly ys) = (F2Poly (dropWhileEnd qs), F2Poly (dropWhileEnd rs)) where (qs, rs) = quotRemBits xs ys divMod = quotRem mod = rem -- | Inputs must be valid for wrapping into F2Poly: no trailing garbage is allowed. xorBits :: U.Vector Bit -> U.Vector Bit -> U.Vector Bit xorBits (BitVec _ 0 _) ys = ys xorBits xs (BitVec _ 0 _) = xs #if UseIntegerGmp -- GMP has platform-dependent ASM implementations for mpn_xor_n, -- which are impossible to beat by native Haskell. xorBits (BitVec 0 lx xarr) (BitVec 0 ly yarr) = case lx `compare` ly of LT -> BitVec 0 ly zs EQ -> dropWhileEnd $ BitVec 0 (lx `min` (sizeofByteArray zs `shiftL` 3)) zs GT -> BitVec 0 lx zs where zs = fromBigNat (toBigNat xarr `xorBigNat` toBigNat yarr) #endif xorBits xs ys = dropWhileEnd $ runST $ do let lx = U.length xs ly = U.length ys (shorterLen, longerLen, longer) = if lx >= ly then (ly, lx, xs) else (lx, ly, ys) zs <- MU.replicate longerLen (Bit False) forM_ [0, wordSize .. shorterLen - 1] $ \i -> writeWord zs i (indexWord xs i `xor` indexWord ys i) U.unsafeCopy (MU.drop shorterLen zs) (U.drop shorterLen longer) U.unsafeFreeze zs -- | Must be >= 2 * wordSize. karatsubaThreshold :: Int karatsubaThreshold = 2048 karatsuba :: U.Vector Bit -> U.Vector Bit -> U.Vector Bit karatsuba xs ys | karatsubaThreshold < 2 * wordSize = error $ "karatsubaThreshold must be >= " ++ show (2 * wordSize) | xs == ys = sqrBits xs | lenXs <= karatsubaThreshold || lenYs <= karatsubaThreshold = mulBits xs ys | otherwise = runST $ do zs <- MU.unsafeNew lenZs forM_ [0, wordSize .. lenZs - 1] $ \k -> do let z0 = indexWord0 zs0 k z11 = indexWord0 zs11 (k - m) z10 = indexWord0 zs0 (k - m) z12 = indexWord0 zs2 (k - m) z2 = indexWord0 zs2 (k - 2 * m) writeWord zs k (z0 `xor` z11 `xor` z10 `xor` z12 `xor` z2) U.unsafeFreeze zs where lenXs = U.length xs lenYs = U.length ys lenZs = lenXs + lenYs - 1 m' = ((lenXs `min` lenYs) + 1) `quot` 2 m = m' - modWordSize m' xs0 = U.unsafeSlice 0 m xs xs1 = U.unsafeSlice m (lenXs - m) xs ys0 = U.unsafeSlice 0 m ys ys1 = U.unsafeSlice m (lenYs - m) ys xs01 = xorBits xs0 xs1 ys01 = xorBits ys0 ys1 zs0 = karatsuba xs0 ys0 zs2 = karatsuba xs1 ys1 zs11 = karatsuba xs01 ys01 indexWord0 :: U.Vector Bit -> Int -> Word indexWord0 bv i | i <= - wordSize = 0 | lenI <= 0 = 0 | i < 0, lenI >= wordSize = word0 | i < 0 = word0 .&. loMask lenI | lenI >= wordSize = word | otherwise = word .&. loMask lenI where lenI = U.length bv - i word = indexWord bv i word0 = indexWord bv 0 `unsafeShiftL` (- i) mulBits :: U.Vector Bit -> U.Vector Bit -> U.Vector Bit mulBits xs ys | lenXs == 0 || lenYs == 0 = U.empty | lenXs >= lenYs = mulBits' xs ys | otherwise = mulBits' ys xs where lenXs = U.length xs lenYs = U.length ys mulBits' :: U.Vector Bit -> U.Vector Bit -> U.Vector Bit mulBits' xs ys = runST $ do zs <- MU.replicate lenZs (Bit False) forM_ [0 .. lenYs - 1] $ \k -> when (unBit (U.unsafeIndex ys k)) $ zipInPlace xor xs (MU.unsafeSlice k (lenZs - k) zs) U.unsafeFreeze zs where lenXs = U.length xs lenYs = U.length ys lenZs = lenXs + lenYs - 1 sqrBits :: U.Vector Bit -> U.Vector Bit sqrBits xs = runST $ do let lenXs = U.length xs zs <- MU.replicate (mulWordSize (nWords lenXs `shiftL` 1)) (Bit False) forM_ [0, wordSize .. lenXs - 1] $ \i -> do let (z0, z1) = sparseBits (indexWord xs i) writeWord zs (i `shiftL` 1) z0 writeWord zs ((i `shiftL` 1) + wordSize) z1 U.unsafeFreeze zs quotRemBits :: U.Vector Bit -> U.Vector Bit -> (U.Vector Bit, U.Vector Bit) quotRemBits xs ys | U.null ys = throw DivideByZero | U.length xs < U.length ys = (U.empty, xs) | otherwise = runST $ do let lenXs = U.length xs lenYs = U.length ys lenQs = lenXs - lenYs + 1 qs <- MU.replicate lenQs (Bit False) rs <- MU.replicate lenXs (Bit False) U.unsafeCopy rs xs forM_ [lenQs - 1, lenQs - 2 .. 0] $ \i -> do Bit r <- MU.unsafeRead rs (lenYs - 1 + i) when r $ do MU.unsafeWrite qs i (Bit True) zipInPlace xor ys (MU.drop i rs) let rs' = MU.unsafeSlice 0 lenYs rs (,) <$> U.unsafeFreeze qs <*> U.unsafeFreeze rs' dropWhileEnd :: U.Vector Bit -> U.Vector Bit dropWhileEnd xs = U.unsafeSlice 0 (go (U.length xs)) xs where go n | n < wordSize = wordSize - countLeadingZeros (indexWord xs 0 .&. loMask n) | otherwise = case indexWord xs (n - wordSize) of 0 -> go (n - wordSize) w -> n - countLeadingZeros w #if UseIntegerGmp bitsToByteArray :: U.Vector Bit -> ByteArray# bitsToByteArray xs = arr where ys = if U.null xs then U.singleton 0 else cloneToWords xs !(P.Vector _ _ (ByteArray arr)) = toPrimVector ys fromBigNat :: BigNat -> ByteArray fromBigNat = unsafeCoerce -- fromBigNat (BN# arr) = ByteArray arr toBigNat :: ByteArray -> BigNat toBigNat = unsafeCoerce -- toBigNat (ByteArray arr) = BN# arr bitsToInteger :: U.Vector Bit -> Integer bitsToInteger xs = bigNatToInteger (BN# (bitsToByteArray xs)) #else integerToBits :: Integer -> U.Vector Bit integerToBits x = U.generate (bitLen x) (Bit . testBit x) bitLen :: Integer -> Int bitLen x = fst $ head $ dropWhile (\(_, b) -> x >= b) $ map (\a -> (a, 1 `shiftL` a)) $ map (1 `shiftL`) $ [lgWordSize..] bitsToInteger :: U.Vector Bit -> Integer bitsToInteger = U.ifoldl' (\acc i (Bit b) -> if b then acc `setBit` i else acc) 0 #endif