-----------------------------------------------------------------------------
-- |
-- Module      :  Codec.Crypto.ECC.F2
-- Copyright   :  (c) Marcel Fourné 2011
-- License     :  BSD3
-- Maintainer  :  Marcel Fourné (hecc@bitrot.dyndns.org)
--
-- F(2^e)-Backend
--
-----------------------------------------------------------------------------

{-# LANGUAGE TypeOperators,FlexibleContexts,FlexibleInstances #-}
module Codec.Crypto.ECC.F2 (f2eAdd,
                            f2eMul,
                            f2eBitshift,
                            f2eReduceBy,
                            f2eFromInteger,
                            f2ePow,
                            f2eToInteger,
                            f2eTestBit,
                            elimFalses,
                            modinvF2,
                            f2eLen)
       where

import Data.List as L
import Numeric
import Data.Char
import Data.Array.Repa as R
import qualified Data.Vector.Unboxed as V

-- | Element-wise equality of one-dimensional unboxed arrays, implemented by
-- comparing the underlying unboxed vectors (length and elements).
--
-- Previously the only method definition of this instance started at column 0,
-- so the layout rule closed the instance with an empty body. '==' then fell
-- back to the class defaults, which are defined in terms of each other and
-- never terminate.
instance (V.Unbox a, Eq a) => Eq (Array U DIM1 a) where
  a1 == a2 = toUnboxed a1 Prelude.== toUnboxed a2

-- | Module-local alias for 'Prelude.==', kept for code that refers to it
-- qualified as @Codec.Crypto.ECC.F2.==@. Because it shares its name with the
-- Prelude operator, unqualified @==@ is ambiguous inside this module, so
-- 'Prelude.==' has to be written explicitly.
(==) :: Eq a => a -> a -> Bool
c == c' = c Prelude.== c'

-- | Exclusive or of two bits, i.e. addition in GF(2).
-- For 'Bool', "xor" is exactly inequality, which is total and needs no
-- unreachable fallback branch.
bxor :: Bool -> Bool -> Bool
bxor a b = a /= b
 
-- TODO: optimise this via C
-- |binary addition of @a1@ and @a2@
f2eAdd :: Array U DIM1 Bool -> Array U DIM1 Bool -> Array U DIM1 Bool
f2eAdd a1 a2 = let l1 = V.length $ toUnboxed a1
                   l2 = V.length $ toUnboxed a2
                   l = if l1 >= l2 then l1
                       else l2 
                   add' a1' a2' = R.zipWith 
                                  (bxor) 
                                  (fillTo a1' l) 
                                  (fillTo a2' l)
               in computeUnboxedP $ add' a1 a2

-- TODO: possibly optimise this via C as well (instead of running it in parallel).
-- Is that necessary? Perhaps use slices and an internal shift instead?
-- | A simple bit shift: a positive @n@ shifts left and a negative @n@ shifts
-- right.
--
-- The first element is the most significant bit. A left shift by @n@ therefore
-- appends @n@ zero bits at the end, which is multiplication by x^n. A right
-- shift keeps only the first @length + n@ elements.
-- NOTE(review): @n@ below minus the array length yields a negative extent;
-- callers presumably never do that — confirm.
f2eBitshift :: Array U DIM1 Bool -> Int -> Array U DIM1 Bool
f2eBitshift a n = computeUnboxedP $ R.traverse a widen pick
  where
    oldLen = V.length (toUnboxed a)
    -- new extent: the old length plus the shift amount
    widen (sh :. len) = sh :. (len + n)
    -- positions past the old end are the freshly shifted-in zero bits
    pick get (sh :. i)
      | i >= oldLen = False
      | otherwise   = get (sh :. i)
-- | Binary multiplication of @a1@ and @a2@: the carry-less product of the two
-- polynomials over GF(2), without modular reduction. Leading zero bits are
-- stripped from the result.
--
-- The bits of @a1@ (padded to the common width) are visited from the most
-- significant one down. For every bit, a shifted operand the size of @a2@ is
-- xor-ed into the accumulator. That operand is @a2@ itself when the bit is set,
-- and an all-zero dummy otherwise. The same sequence of operations therefore
-- runs regardless of the bit values; the original code intends this as
-- timing-attack resistance.
f2eMul :: Array U DIM1 Bool -> Array U DIM1 Bool -> Array U DIM1 Bool
f2eMul a1 a2 = elimFalses $ go (toUnboxed $ fillTo a1 width) zeros
  where
    len2    = V.length $ toUnboxed a2
    width   = max (V.length $ toUnboxed a1) len2
    prodLen = 2 * width - 1
    zeros   = R.fromUnboxed (Z :. prodLen) $ V.replicate prodLen False
    dummy   = R.fromUnboxed (Z :. len2) $ V.replicate len2 False
    go bits acc
      | V.null bits = acc
      | otherwise   =
          let shift  = V.length bits - 1
              addend = if V.head bits then a2 else dummy
          in go (V.tail bits) (f2eAdd acc (fillTo (f2eBitshift addend shift) prodLen))

-- | Polynomial reduction of @a@ modulo @r@ over GF(2). Returns the remainder
-- with leading zero bits stripped.
--
-- A one-element @r@ is special-cased: the polynomial 1 yields zero, and the
-- polynomial 0 returns @a@ unchanged.
-- NOTE(review): the general case assumes @r@ is normalised (its first element
-- is set, e.g. via 'elimFalses') — confirm at call sites.
f2eReduceBy :: Array U DIM1 Bool -> Array U DIM1 Bool -> Array U DIM1 Bool
f2eReduceBy a r
  | modLen Prelude.== 1 =
      if V.head (toUnboxed r) then f2eFromInteger 0 else a
  | otherwise =
      elimFalses (fromUnboxed (Z :. V.length remainder) remainder)
  where
    modLen    = V.length (toUnboxed r)
    dummy     = R.fromUnboxed (Z :. modLen) (V.replicate modLen False)
    remainder = reduce (toUnboxed a)
    -- Repeatedly cancel the leading element of @z@ and drop it, until @z@ is
    -- shorter than the modulus. When the leading bit is clear, an all-zero
    -- dummy of the modulus' size is xor-ed in instead of @r@, so both paths
    -- perform the same operations (timing-attack resistance, per the
    -- original comments).
    reduce z
      | V.length z < modLen = z
      | otherwise =
          let width   = V.length z
              operand = if V.head z then r else dummy
              aligned = toUnboxed (fillTo (f2eBitshift operand (width - modLen)) width)
          in reduce (V.tail (V.zipWith bxor z aligned))

-- | @f2ePow b k@ raises the GF(2) polynomial @b@ to the non-negative power
-- @k@. No modular reduction is performed; reduce the result separately if
-- needed.
--
-- The exponents 2 and 3, which the original comment names as the only cases
-- used, are unrolled exactly as before. 0 and 1 are answered directly. Any
-- other exponent uses left-to-right square-and-multiply over the binary
-- digits of @k@; previously such exponents silently returned @b@. The loop
-- branches on the bits of @k@, so it is not constant-time in @k@.
-- A negative exponent is an error.
f2ePow :: Array U DIM1 Bool -> Integer -> Array U DIM1 Bool
f2ePow b k
  | k < 0          = error "Codec.Crypto.ECC.F2.f2ePow: negative exponent"
  | k Prelude.== 0 = f2eFromInteger 1
  | k Prelude.== 1 = b
  | k Prelude.== 2 = f2eMul b b
  | k Prelude.== 3 = f2eMul b $ f2eMul b b
  -- k >= 4: its binary expansion starts with '1', which is accounted for by
  -- starting the accumulator at b
  | otherwise      = L.foldl' step b (L.drop 1 (binary k))
  where
    step acc digit
      | digit Prelude.== '1' = f2eMul b squared
      | otherwise            = squared
      where
        squared = f2eMul acc acc



-- | Left-pad @a@ with zero bits (False) so that it has @n@ elements.
-- Because the first element is the most significant bit, the value is
-- unchanged. An array that already has at least @n@ elements is returned as is.
fillTo :: Array U DIM1 Bool -> Int -> Array U DIM1 Bool
fillTo a n
  | len < n   = fromUnboxed (Z :. n) (V.replicate (n - len) False V.++ bits)
  | otherwise = a
  where
    bits = toUnboxed a
    len  = V.length bits

-- | Keep only the last @|n|@ elements of @a@, dropping leading ones.
-- NOTE(review): if @|n|@ exceeds the length of @a@, nothing is dropped but
-- the extent is still set to @|n|@ — callers presumably never do that.
shortenTo :: Array U DIM1 Bool -> Int -> Array U DIM1 Bool
shortenTo a n = fromUnboxed (Z :. keep) (V.drop (V.length bits - keep) bits)
  where
    bits = toUnboxed a
    keep = abs n
                   
-- | Strip leading zero (False) elements, normalising the polynomial so that
-- its first element is its highest set coefficient. At least one element is
-- always kept, so the zero polynomial becomes a single False.
elimFalses :: Array U DIM1 Bool -> Array U DIM1 Bool
elimFalses a = shortenTo a keep
  where
    bits = toUnboxed a
    keep = case V.findIndex id bits of
             Just firstSet -> V.length bits - firstSet
             Nothing       -> 1

-- | Base-2 digits of a non-negative integer, most significant digit first
-- (e.g. @binary 5 == "101"@).
binary :: Integer -> String
binary n = showIntAtBase 2 intToDigit n ""

-- | Convert a non-negative integer into a bit array, most significant bit
-- first. 0 becomes a single False element.
f2eFromInteger :: Integer -> Array U DIM1 Bool
f2eFromInteger z = fromListUnboxed (Z :. L.length digits) (L.map (Prelude.== '1') digits)
  where
    digits = binary z
                      
-- | Interpret a bit array (most significant bit first) as a non-negative
-- integer, using Horner's scheme: each step doubles the accumulator and adds
-- the next bit.
f2eToInteger :: Array U DIM1 Bool -> Integer
f2eToInteger z = V.foldl' step 0 (toUnboxed z)
  where
    step acc bit = 2 * acc + (if bit then 1 else 0)

-- | Read element @i@ of the bit array. Positions count from the first element,
-- which is the most significant bit. This is the reverse of
-- 'Data.Bits.testBit'.
--
-- The bounds check used to accept @i == length@ (off by one) and contained a
-- vacuous @l >= 0@ test. An index outside @[0, length)@ now fails with a
-- descriptive error instead of @undefined@ or an out-of-bounds index.
f2eTestBit :: Array U DIM1 Bool -> Int -> Bool
f2eTestBit k i
  | i >= 0 && i < V.length (toUnboxed k) = index k (Z :. i)
  | otherwise = error "Codec.Crypto.ECC.F2.f2eTestBit: bit index out of range"

-- | Compute the modular inverse of @a@ modulo @f@, i.e. @g@ with
-- @a * g == 1 (mod f)@. It uses the extended Euclidean algorithm for binary
-- polynomials, with degrees taken from the lengths of normalised arrays.
--
-- The termination test previously went through the module's Eq instance for
-- arrays. That instance never terminates, so the function hung. The test now
-- inspects the array directly.
-- Inputs are normalised with 'elimFalses' first, because degrees are derived
-- from array lengths. If @a@ has no inverse (the gcd is not 1, so @u@ reaches
-- zero) or @f@ is zero, this fails with an error instead of looping forever.
modinvF2 :: Array U DIM1 Bool -- ^the polynomial to invert
            -> Array U DIM1 Bool -- ^the modulus
            -> Array U DIM1 Bool -- ^the inverted value
modinvF2 a f = go (elimFalses a) (elimFalses f) (f2eFromInteger 1) (f2eFromInteger 0)
  where
    go u v g1 g2
      | isOne u                = g1
      | isZero u || isZero v   =
          error "Codec.Crypto.ECC.F2.modinvF2: polynomial is not invertible modulo the given modulus"
      -- deg u < deg v: swap the roles of (u, g1) and (v, g2) while eliminating
      | j < 0                  =
          go (elimFalses (v `f2eAdd` f2eBitshift u (negate j)))
             u
             (elimFalses (g2 `f2eAdd` f2eBitshift g1 (negate j)))
             g1
      | otherwise              =
          go (elimFalses (u `f2eAdd` f2eBitshift v j))
             v
             (elimFalses (g1 `f2eAdd` f2eBitshift g2 j))
             g2
      where
        -- degree difference (arrays are normalised, so length = degree + 1)
        j = f2eLen u - f2eLen v
    -- a normalised polynomial is 1 / 0 exactly when it is a single True / False
    isOne p  = f2eLen p Prelude.== 1 && V.head (toUnboxed p)
    isZero p = f2eLen p Prelude.== 1 && not (V.head (toUnboxed p))

-- | Number of elements (bits) in an unboxed Repa array. For a normalised
-- polynomial (see 'elimFalses') this is its degree plus one.
f2eLen :: (Shape sh, V.Unbox e) => Array U sh e -> Int
f2eLen a = V.length (toUnboxed a)