module Data.Digest.WebMoney.Algebra where
import Data.Bits (Bits, bitSize, shiftL, shiftR, testBit, (.&.),
(.|.))
import Data.Int (Int32, Int64)
import Data.Word (Word32, Word64)
import Control.Lens (ix, (&), (.~))
import Data.Vector (Vector, singleton, (!))
import qualified Data.Vector as V (init, last, length, null, replicate, take,
(++))
-- | Mask selecting the low 32 bits of an 'Int64'. ANDing a widened 'Int32'
-- limb with it yields the limb's unsigned value.
longMask :: Int64
longMask = 0x00000000FFFFFFFF
-- | Width in bits of one limb ('Int32') of the vector representation:
-- four bytes of eight bits.
intSize :: Int
intSize = 4 * 8
-- | Zero-filling right shift, performed on the 64-bit unsigned image of the
-- argument and converted back.
--
-- NOTE(review): the shift happens on a 'Word64', so this is a true logical
-- shift only for 64-bit inputs. Narrower signed types keep their
-- sign-extension bits. The caller in this module ('shiftRight') passes
-- 'Int64'.
logicalShiftR :: Integral a => a -> Int -> a
logicalShiftR x i = fromIntegral shifted
  where
    unsigned = fromIntegral x :: Word64
    shifted = unsigned `shiftR` i
-- | Zero-filling right shift of a 32-bit limb. The limb is reinterpreted as
-- a 'Word32', shifted, and converted back, so the sign bit is not
-- replicated.
logicalShiftRight :: Int32 -> Int -> Int32
logicalShiftRight x i =
  let asWord = fromIntegral x :: Word32
  in fromIntegral (shiftR asWord i)
-- | Number of significant bits in @x@: its width minus its leading zeros.
-- @0@ has 0 significant bits, and a value with its top bit set has its
-- full width.
--
-- The width is taken from @x@ itself ('bitSize'), matching how
-- 'numberOfLeadingZeros' scans. For 'Int32' this is the same as 'intSize'.
getBitsNumber :: Bits a => a -> Int
getBitsNumber x = bitSize x - numberOfLeadingZeros x
-- | Count of zero bits above the highest set bit, scanning from the most
-- significant bit position downwards. Returns the full width for @0@.
numberOfLeadingZeros :: Bits a => a -> Int
numberOfLeadingZeros x =
  length . takeWhile (not . testBit x) $ [top, top - 1 .. 0]
  where
    -- index of the most significant bit position
    top = bitSize x - 1
-- | Total number of significant bits in a little-endian vector of 32-bit
-- limbs (index 0 is the least significant limb).
--
-- An empty or all-zero vector has 0 significant bits. Previously that
-- case indexed position -1 and crashed.
getBitsCount :: (Bits a, Num a) => Vector a -> Int
getBitsCount xs
  | used == 0 = 0
  | otherwise = (used - 1) * intSize + getBitsNumber (xs ! (used - 1))
  where
    -- number of limbs up to and including the highest non-zero one
    used = significance xs
-- | Compares two little-endian limb vectors as unsigned big integers.
--
-- Zero limbs at the high end are ignored, so vectors of different lengths
-- that encode the same number compare 'EQ'. Otherwise the vector with more
-- significant limbs is larger. Equal-length values are compared limb by
-- limb, as unsigned 32-bit quantities, from the most significant limb
-- downwards.
compareLists :: Vector Int32 -> Vector Int32 -> Ordering
compareLists lhs rhs =
  case compare lhsUsed rhsUsed of
    EQ -> foldr byLimb EQ [lhsUsed - 1, lhsUsed - 2 .. 0]
    unequal -> unequal
  where
    lhsUsed = significance lhs
    rhsUsed = significance rhs
    -- widen a limb and drop its sign so it compares as unsigned
    unsigned :: Int32 -> Int64
    unsigned limb = fromIntegral limb .&. longMask
    -- lazy right fold: stops at the first limb pair that differs
    byLimb i rest =
      case compare (unsigned (lhs ! i)) (unsigned (rhs ! i)) of
        EQ -> rest
        unequal -> unequal
-- | Number of limbs that remain after discarding zero limbs from the high
-- end (the end of the vector). Gives 0 for an empty or all-zero vector.
significance :: (Eq a, Bits a, Num a) => Vector a -> Int
significance xs = go (V.length xs)
  where
    -- n is the length of the prefix still under consideration
    go 0 = 0
    go n
      | xs ! (n - 1) == 0 = go (n - 1)
      | otherwise = n
-- | Shifts a little-endian vector of 32-bit limbs (index 0 least
-- significant), read as an unsigned big integer, by @rhs@ bits.
--
-- * A positive @rhs@ shifts towards the most significant end
--   (multiplication by @2^rhs@).
-- * A negative @rhs@ shifts towards the least significant end
--   (floor division by @2^(-rhs)@).
-- * @0@ keeps the value.
--
-- When every significant bit is shifted out, the result is @singleton 0@.
-- Limbs above the significant part of the result, if any, are zero.
shift :: Vector Int32 -> Int -> Vector Int32
shift lhs rhs
  | outWordsCount <= 0 = singleton 0
  | rhs >= 0 && shiftBits == 0 =
      -- whole-limb left shift (or no shift at all): prepend zero limbs
      V.take shiftWords r0 V.++ V.take (outWordsCount - shiftWords) lhs
  | rhs > 0 =
      let (res, carry) = foldl stepUp (r0, 0) [0 .. inWordsCount - 1]
          top = inWordsCount + shiftWords
      in if top < V.length res
           -- bits carried out of the highest input limb land here
           then res & ix top .~ ((res ! top) .|. carry)
           else res
  | otherwise =
      -- right shift: output limb pos is built from input limbs
      -- pos + shiftWords and pos + shiftWords + 1; seed the carry with the
      -- input limb just above the top output limb's source, if it exists
      let carry0 =
            if outWordsCount + shiftWords < inWordsCount
              then spill (lhs ! (outWordsCount + shiftWords))
              else 0
      in fst $ foldl stepDown (r0, carry0) [outWordsCount - 1, outWordsCount - 2 .. 0]
  where
    shiftBits, shiftWords, inBitsCount, inWordsCount, outBitsCount, outWordsCount :: Int
    shiftBits = abs rhs `mod` intSize
    shiftWords = abs rhs `div` intSize
    inBitsCount = getBitsCount lhs
    inWordsCount = wordsFor inBitsCount
    outBitsCount = inBitsCount + rhs
    outWordsCount = wordsFor outBitsCount
    -- limbs needed to hold a bit count (ceiling division by the limb width)
    wordsFor :: Int -> Int
    wordsFor bits = bits `div` intSize + (if bits `mod` intSize > 0 then 1 else 0)
    r0 = V.replicate (max inWordsCount outWordsCount) 0
    -- low bits of limb w that cross into the next lower limb on a right
    -- shift, already moved into position; nothing crosses when the shift
    -- is a whole number of limbs
    spill :: Int32 -> Int32
    spill w
      | shiftBits == 0 = 0
      | otherwise = w `shiftL` (intSize - shiftBits)
    stepUp, stepDown :: (Vector Int32, Int32) -> Int -> (Vector Int32, Int32)
    -- left shift: input limb pos moves to pos + shiftWords, and its top
    -- shiftBits bits carry into the next higher output limb
    stepUp (res, carry) pos = (res & ix (pos + shiftWords) .~ val, nextCarry)
      where
        temp = lhs ! pos
        val = (temp `shiftL` shiftBits) .|. carry
        nextCarry = temp `logicalShiftRight` (intSize - shiftBits)
    -- right shift: output limb pos takes input limb pos + shiftWords moved
    -- down, plus the bits spilled from the input limb above it
    stepDown (res, carry) pos = (res & ix pos .~ val, spill temp)
      where
        temp = lhs ! (pos + shiftWords)
        val = (temp `logicalShiftRight` shiftBits) .|. carry
-- | Shifts a little-endian limb vector right by exactly one bit (floor
-- division by 2 of the unsigned value). Only the significant limbs are
-- rewritten, and the vector's length is unchanged.
shiftRight :: Vector Int32 -> Vector Int32
shiftRight value = fst $ foldl step (value, 0) [len - 1, len - 2 .. 0]
  where
    len = significance value
    -- Walks from the most significant limb down. The carry is the bit that
    -- fell out of the limb above, already moved to bit 31.
    step :: (Vector Int32, Int64) -> Int -> (Vector Int32, Int64)
    step (v, carry) pos = (v & ix pos .~ fromIntegral val, nextCarry)
      where
        temp, nextCarry, val :: Int64
        -- limb as an unsigned 32-bit value
        temp = fromIntegral (v ! pos) .&. longMask
        nextCarry = ((temp .&. 1) `shiftL` (intSize - 1)) .&. longMask
        val = ((temp `logicalShiftR` 1) .|. carry) .&. longMask
-- | Difference @lhs - rhs@ of two little-endian limb vectors, read as
-- unsigned big integers. The result keeps @lhs@'s vector length.
--
-- Calls 'error' with "Difference should not be negative." when the result
-- would be negative, i.e. when @rhs > lhs@.
sub :: Vector Int32 -> Vector Int32 -> Vector Int32
sub lhs rhs
  | lhsLength < rhsLength = error "Difference should not be negative."
  | otherwise = checkBorrow $ foldl step (lhs, 0) [0 .. lhsLength - 1]
  where
    lhsLength = significance lhs
    rhsLength = significance rhs
    -- a borrow out of the top limb means the true result is negative
    checkBorrow :: (Vector Int32, Int32) -> Vector Int32
    checkBorrow (_, 1) = error "Difference should not be negative."
    checkBorrow (l, _) = l
    -- unsigned value of rhs's limb; limbs above its significant part are 0
    rhsLimb :: Int -> Int64
    rhsLimb pos
      | pos < rhsLength = fromIntegral (rhs ! pos) .&. longMask
      | otherwise = 0
    -- Subtracts one limb pair plus the incoming borrow in 64-bit arithmetic.
    -- A negative result lies in [-2^32, -1], so bit 32 is set exactly when
    -- a new borrow is needed.
    step :: (Vector Int32, Int32) -> Int -> (Vector Int32, Int32)
    step (l, borrow) pos = (l & ix pos .~ fromIntegral temp, nBorrow)
      where
        temp :: Int64
        temp = (fromIntegral (l ! pos) .&. longMask) - rhsLimb pos - fromIntegral borrow
        nBorrow = if temp .&. (1 `shiftL` intSize) /= 0 then 1 else 0
-- | Remainder of dividing @lhs@ by @rhs@, both little-endian limb vectors
-- read as unsigned big integers. Uses repeated shift-and-subtract: the
-- divisor is aligned under the dividend's top bit and subtracted until the
-- dividend drops below the divisor.
--
-- Calls 'error' with "Division by zero." when @rhs@ is zero. Previously a
-- zero divisor crashed with an index error or did not terminate.
remainder :: Vector Int32 -> Vector Int32 -> Vector Int32
remainder lhs rhs
  | significance rhs == 0 = error "Division by zero."
  | otherwise = divide lhs rhs
  where
    rhsBitsCount = getBitsCount rhs
    divide :: Vector Int32 -> Vector Int32 -> Vector Int32
    divide l r
      | LT == compareLists l r = l
      | lhsBitsCount == 0 = l
      | otherwise = divide (subs l aligned) r
      where
        lhsBitsCount = getBitsCount l
        -- how far r must move up so its top bit lines up with l's top bit
        distance = lhsBitsCount - rhsBitsCount
        temp
          | distance == 0 = r
          | otherwise = shift r distance
        -- if the aligned divisor overshoots l, back it off by one bit
        aligned = if compareLists l temp == LT then shiftRight temp else temp
    -- subtracts t from l for as long as l >= t
    subs :: Vector Int32 -> Vector Int32 -> Vector Int32
    subs l t
      | compareLists l t /= LT = subs (sub l t) t
      | otherwise = l
-- | Makes @v@ exactly @l@ limbs long. It is padded with zero limbs at the
-- high end when too short, or truncated from the high end when too long.
-- Calls 'error' with "Invalid value for length" for a negative @l@.
resize :: Vector Int32 -> Int -> Vector Int32
resize v l
  | l < 0 = error "Invalid value for length"
  | vLength < l = v V.++ V.replicate (l - vLength) 0
  | otherwise = V.take l v
  where
    vLength = V.length v
-- | Trims zero limbs from the high end of @x@, keeping only its significant
-- limbs. An all-zero vector becomes empty.
normalize :: Vector Int32 -> Vector Int32
normalize x = resize x significant
  where
    significant = significance x