-----------------------------------------------------------------------------
-- |
-- Module      :  Data.F2
-- Copyright   :  (c) Marcel Fourné 2011-2013
-- License     :  BSD3
-- Maintainer  :  Marcel Fourné (mail@marcelfourne.de)
-- Stability   :  experimental
-- Portability :  Good
--
-- A (hopefully) timing-attack-resistant F(2^e) backend,
-- i.e. arithmetic on polynomials over GF(2) in binary representation.
-- The presented interface is Big Endian, like Data.Bits
-- All indices are [0 .. (length - 1)]
-- TODO: further optimization
--
-----------------------------------------------------------------------------

{-# OPTIONS_GHC -O2 -fllvm -optlo-O3 -feager-blackholing #-}
{-# LANGUAGE BangPatterns #-}

module Data.F2 (
  F2,
  pow,
  toInteger,
  even,
  odd,
  mod,
  div,
  bininv
  ) where

import Prelude hiding ((^),fromInteger,toInteger,even,odd,div,mod)
import qualified Prelude as P ((^),fromInteger,toInteger,even,odd,div,mod)
import qualified Numeric as N (showIntAtBase)
import qualified Data.Char as C (intToDigit)
import qualified Data.Bits as B (Bits,bit,bitSize,complement,isSigned,popCount,rotate,shift,testBit,xor,(.|.),(.&.))
import qualified Data.Vector.Unboxed as V (Vector,reverse,dropWhile,length,singleton,last,mapM_,fromList,replicate,(++),zipWith,head,tail,take,map,drop,(!),foldl')
import qualified Data.Word as W (Word)
import Data.Serialize as S (Serialize,put,get)
import Control.Monad (replicateM)

-- internal helper definitions
-- | Largest value of a machine 'W.Word', as an 'Integer'.
wordMaxF2 :: Integer
wordMaxF2 = let top = maxBound :: W.Word
            in P.toInteger top
-- | Number of bits in a machine 'W.Word'.
wordSizeF2 :: Int
wordSizeF2 = B.bitSize (maxBound :: W.Word)

-- | Render an integral value as its binary digits, most significant first,
-- e.g. @binary 5 == "101"@.
binary :: (Integral a, Show a) => a -> [Char]
binary n = N.showIntAtBase 2 C.intToDigit n ""

-- | Helper that normalises an 'F2' to its minimal representation.
-- Bits stored at or above the declared length are discarded first. All-zero
-- high words are then dropped, and the length is recomputed as
-- (index of the highest set bit) + 1. The zero polynomial becomes @F2 1 [0]@.
shorten :: F2 -> F2
shorten !(F2 !la !va) =
  let (q, r) = la `quotRem` wordSizeF2
      -- one mask per word covering bits [0 .. la - 1]; the top word gets a
      -- partial mask when la is not a multiple of the word size
      masks | r == 0 = V.replicate q maxBound
            | otherwise = V.replicate q maxBound V.++ V.singleton (B.bit r - 1)
      -- zipWith truncates, so words beyond the declared length are dropped
      masked = V.zipWith (B..&.) va masks
      -- words are stored least significant first: strip zero words at the end
      trimmed = V.reverse $ V.dropWhile (== 0) $ V.reverse masked
      vnew = if V.length trimmed == 0
             then V.singleton 0
             else trimmed
      topWord = V.last vnew
      -- scan the top word from its highest valid bit index downwards
      bitLength i
        | i < 0 = 1 -- only reached for the zero polynomial
        | B.testBit topWord i = i + (V.length vnew - 1) * wordSizeF2 + 1
        | otherwise = bitLength (i - 1)
  in F2 (bitLength (wordSizeF2 - 1)) vnew

-- | A polynomial over GF(2), stored as a bit string.
--
-- The 'Int' field is the number of meaningful bits (the length). The vector
-- holds those bits packed into machine words, least significant word at
-- index 0 ('fromInteger' builds it that way and 'toInteger' reads it back
-- that way). The vector may be larger than strictly needed. The bit order
-- inside a word is handled by "Data.Bits", so machine endianness is of no
-- concern.
--
-- Be careful with the bit/word indices: converting between them with
-- 'quotRem' has caused some headaches.
data F2 = F2 {-# UNPACK #-} !Int !(V.Vector W.Word) 
        deriving (Show) -- Note: Ord, Enum do not make sense for F2, but perhaps by translation to Integer, what about Real?

-- | Two 'F2' values are equal exactly when their 'toInteger' images are equal.
instance Eq F2 where
  -- TODO: more thought on (==), timing-attack resistance
  (==) !x !y = ix == iy
    where ix = toInteger x
          iy = toInteger y

-- | Binary encoding: the bit length (after 'shorten'), followed by the words
-- of the vector in storage order. Decoding reads as many words as are needed
-- to hold that many bits.
instance Serialize F2 where
  put !a = S.put len >> V.mapM_ S.put ws
    where F2 len ws = shorten a
  get = do
    len <- S.get
    let (whole, extra) = len `quotRem` wordSizeF2
        count = if extra == 0 then whole else whole + 1
    ws <- replicateM count S.get
    return (F2 len (V.fromList ws))

-- | Ring operations on GF(2)[x].
instance Num F2 where
  -- Addition of polynomials over GF(2) is coefficient-wise xor.
  (+) = B.xor
  -- Every element of GF(2)[x] is its own additive inverse (a + a = 0), so
  -- negation is the identity and subtraction coincides with addition.
  -- Without this definition the class defaults for 'negate' and '(-)' are
  -- defined in terms of each other and never terminate.
  negate !a = a
  -- Shift-and-add multiplication: for every bit i of a (i < la), b shifted by
  -- i is xored into the accumulator. When bit i is clear, an all-zero value
  -- with as many words as b is shifted and xored instead, so both branches
  -- perform the same kind of operations.
  -- The timing attack resistance for (*) looks brittle, needs more careful thought
  (*) !a@(F2 !la !va) !b@(F2 !lb !vb) =
    let vl1 = V.length va
        vl2 = V.length vb
        nullen = F2 la $ V.replicate vl1 0
        pseudo = F2 lb $ V.replicate vl2 0
        fun i b1 | i < la = if B.testBit a i
                            -- real branch
                            then fun (i + 1) (b1 `B.xor` (B.shift b i))
                            -- for timing-attack-resistance xor with 0s
                            else fun (i + 1) (b1 `B.xor` (B.shift pseudo i))
                 | otherwise = b1
    in fun 0 nullen
  -- always abs
  abs !a = a
  -- always unsigned
  signum _ = 1
  -- Non-negative Integers map to the polynomial with the same bit pattern;
  -- the vector is built least significant word first.
  fromInteger !i =
    if i >= 0
    then let bin = binary i
             helper a =
               if a <= wordMaxF2 then V.singleton $ P.fromInteger a
               else let (d,rest) = quotRem a (wordMaxF2 + 1)
                    in  (V.singleton $ P.fromInteger rest) V.++ (helper d)
         in F2 (length bin) (helper i)
    else error "F2 are only defined for non-negative Integers"

-- | Combine two 'F2' values word by word with @op@.
-- Words are stored least significant first (index 0), so the shorter vector
-- is padded with zero words at its high end before zipping. The result length
-- is the larger of the two declared lengths.
zipWordsF2 :: (W.Word -> W.Word -> W.Word) -> F2 -> F2 -> F2
zipWordsF2 op !(F2 !la !va) !(F2 !lb !vb) =
  let width = max (V.length va) (V.length vb)
      padTo v = v V.++ V.replicate (width - V.length v) 0
  in F2 (max la lb) $ V.zipWith op (padTo va) (padTo vb)

instance B.Bits F2 where
  (.&.) = zipWordsF2 (B..&.)
  (.|.) = zipWordsF2 (B..|.)
  xor = zipWordsF2 B.xor
  -- Flips every stored bit, including any above the declared length; such
  -- bits are ignored by 'testBit' and 'popCount' below.
  complement !(F2 !la !va) = F2 la $ V.map (B.complement) va
  -- Shift by i bits: towards higher bits for i > 0, lower bits for i < 0. The
  -- length changes by i, and a result with no bits left becomes zero.
  -- Write i = q * wordSizeF2 + r with 0 <= r < wordSizeF2. Output word k then
  -- combines source word k - q (shifted up by r) with the carry out of source
  -- word k - q - 1. The work done depends only on the lengths and on i, not
  -- on the bit values.
  -- Machine endianness is not important, Data.Bits handles it.
  shift !a@(F2 !la !va) !i
    | i == 0 = a
    | newlen < 1 = F2 1 $ V.singleton 0
    | otherwise = F2 newlen $ V.fromList $ map wordAt [0 .. newwords - 1]
    where
      newlen = la + i
      newwords = (newlen + wordSizeF2 - 1) `quot` wordSizeF2
      -- floor division, so negative (right) shifts decompose correctly too
      (q, r) = i `divMod` wordSizeF2
      -- source word k, or zero outside the stored vector
      src k | k >= 0 && k < V.length va = (V.!) va k
            | otherwise = 0
      wordAt k
        | r == 0 = src (k - q)
        | otherwise = B.shift (src (k - q)) r
                      B..|. B.shift (src (k - q - 1)) (r - wordSizeF2)
  -- NOTE: not a cyclic rotation; this simply delegates to 'shift'.
  rotate !a !i = B.shift a i
  bitSize !(F2 !l _)= l
  isSigned _ = False
  bit !i = P.fromInteger $ 2 P.^ i
  -- Only bits 0 .. la - 1 belong to the value; anything else reads as False.
  testBit !(F2 !la !va) !i
    | i < 0 || i >= la = False
    | w >= V.length va = False
    | otherwise = B.testBit ((V.!) va w) b
    where (w, b) = i `quotRem` wordSizeF2
  -- Counts set bits among positions 0 .. la - 1 only.
  popCount !(F2 !la !va) =
    let (q, r) = la `quotRem` wordSizeF2
        whole = V.foldl' (+) 0 $ V.map B.popCount $ V.take q va
        top | r == 0 || q >= V.length va = 0
            | otherwise = B.popCount ((V.!) va q B..&. (B.bit r - 1))
    in whole + top

-- instance Real F2 where -- TODO?

-- instance Integral F2 where -- TODO?

-- | Conversion to Integer.
-- Words are stored least significant first. Only the low @la@ bits (the
-- declared length) contribute to the result; any bits stored at or above
-- that position are discarded, for single- and multi-word values alike.
toInteger :: F2 -> Integer
toInteger !(F2 !la !va) =
  let -- fold from the most significant word down, shifting the accumulator
      -- up by one word before merging in the next word
      step acc w = B.shift acc wordSizeF2 B..|. P.toInteger w
      combined = V.foldl' step 0 (V.reverse va)
  in rem combined (2 P.^ P.toInteger la)

-- | Polynomial reduction, a.k.a. modulo on polynomials.
-- The divisor is first normalised with 'shorten', so that its length is its
-- degree + 1. At each step the reduction loop cancels bit i - 1 of the
-- running value by xoring in b shifted so its top bit lands there, which is
-- only correct for a normalised divisor. Reducing by the zero polynomial
-- returns @a@ unchanged.
mod :: F2 -- ^ a
       -> F2 -- ^ b
       -> F2 -- ^ a `mod` b
mod !a@(F2 !la _) !braw
  | b == (P.fromInteger 0) = a
  | b == (P.fromInteger 1) = P.fromInteger 0
  | otherwise = let lbv = V.length vb
                    pseudo = F2 lbv $ V.replicate lbv 0
                    fun !z@(F2 _ !v) i | i >= lb = if B.testBit z (i - 1)
                                                   -- real branch
                                                   then fun (z + (B.shift b (i - lb))) (i - 1)
                                                   -- for timing-attack-resistance xor with 0s
                                                   else fun (z + (B.shift pseudo (i - lb))) (i - 1)
                                       | otherwise = F2 i $ V.take ((i `quot` wordSizeF2) + 1) v -- shortening
                in fun a $ la
  where
    b@(F2 lb vb) = shorten braw

-- | The power function on F2 (no modular reduction is applied).
-- For k >= 4, it walks the bits of k below the leading one, from high to
-- low. It keeps a pair (lo, hi) with hi = lo * a: a set bit sends the pair
-- to (lo * hi, hi^2), and a clear bit sends it to (lo^2, lo * hi).
pow :: F2 -- ^ a
       -> Integer -- ^ k
       -> F2 -- ^ a^k
pow !a !k
  | k < 0 = error "negative exponent for the power function on F2"
  | k == 0 = P.fromInteger 1
  | k == 1 = a
  | k == 2 = a * a
  | k == 3 = a * a * a
  | otherwise = ladder a (sq a) (length (binary k) - 2)
  where
    sq z = z * z
    ladder lo hi idx
      | idx < 0 = lo
      | B.testBit k idx = ladder (lo * hi) (sq hi) (idx - 1)
      | otherwise = ladder (sq lo) (lo * hi) (idx - 1)

-- | O(1): True when the constant coefficient (bit 0 of the lowest word) is 0.
even :: F2 -> Bool
even !(F2 _ !v) = not (B.testBit (V.head v) 0)

-- | O(1): True when the constant coefficient (bit 0 of the lowest word) is 1.
odd :: F2 -> Bool
odd !(F2 _ !v) = B.testBit (V.head v) 0

-- | Polynomial division modulo m. It takes three parameters instead of two:
-- k/f mod m is computed as k times the binary inverse of f in m, reduced
-- modulo m.
div :: F2 -- ^ k
       -> F2 -- ^ f
       -> F2 -- ^ m
       -> F2 -- ^ k/f `mod` m
div !k !f !m = mod scaled m
  where scaled = k * bininv f m
                         
-- | Binary inversion of f in m: returns g with f * g == 1 modulo m.
-- Both inputs are normalised with 'shorten' first, because each step uses
-- the difference of the bit lengths as the degree difference. If u ever
-- reaches 0, f has no inverse modulo m (f is zero, or f and m share a
-- factor). That case is reported as an error instead of looping forever.
bininv :: F2 -- ^ f
          -> F2 -- ^ m
          -> F2 -- ^ the binary inverse of f in m
bininv !f !m = helper (shorten f) (shorten m) (P.fromInteger 1) (P.fromInteger 0)
  where
    helper :: F2 -> F2 -> F2 -> F2 -> F2
    helper !u@(F2 lu _) !v@(F2 lv _) !g1 !g2
      | u == (P.fromInteger 1) = g1
      | u == (P.fromInteger 0) = error "bininv: f has no inverse modulo m"
      -- deg u < deg v: swap roles, then cancel the top bit of v
      | j < 0 = helper (shorten $ v + (B.shift u (-j))) u (shorten $ g2 + (B.shift g1 (-j))) g1
      | otherwise = helper (shorten $ u + (B.shift v j)) v (shorten $ g1 + (B.shift g2 j)) g2
      where j = lu - lv