{-# LANGUAGE PatternGuards #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Codec.Encryption.ECC.Base
-- Copyright   :  (c) Marcel Fourné 2009
-- License     :  MIT-X11-License
-- Maintainer  :  Marcel Fourné (hecc@bitrot.dyndns.org)
--
-- ECC Base algorithms & point formats
--
-----------------------------------------------------------------------------

module Codec.Encryption.ECC.Base (ECInt(), 
                                  ECP(..),
                                  EC(..),
                                  modinv, 
                                  pmul, 
                                  ison,
                                  EPa(..), 
                                  EPp(..), 
                                  EPj(..), 
                                  EPmj(..))
    where 

-- |integer type used for all coordinates, curve parameters and scalars in this module;
-- currently an alias for the unbounded 'Integer'. This may change in the future if the need arises
type ECInt = Integer

-- |extended Euclidean algorithm (recursive): @eeukl a b@ returns @(g, s, t)@
-- such that @g == s*a + t*b@, where @g@ is a greatest common divisor of @a@ and @b@
eeukl :: ECInt -> ECInt -> (ECInt, ECInt, ECInt)
eeukl a b
    | b == 0    = (a, 1, 0)
    | otherwise = (g, t', s' - q * t')
  where
    -- quotient and remainder of one Euclid step (div/mod semantics)
    (q, r)      = a `divMod` b
    -- Bezout coefficients for (b, r), rotated back into coefficients for (a, b)
    (g, s', t') = eeukl b r

-- |computing the modular inverse: @modinv a m@ is the @y@ with @a*y == 1 (mod m)@,
-- normalised by @mod y m@ (i.e. into [0, m) for positive @m@).
-- Calls 'error' with a descriptive message when @a@ is not invertible modulo @m@,
-- i.e. when the gcd reported by 'eeukl' is not 1 (for example @a == 0 (mod m)@).
modinv :: ECInt -> ECInt -> ECInt
modinv a m = let (x,y,_) = eeukl a m
             in if x == 1 
                then mod y m
                else error ("Codec.Encryption.ECC.Base.modinv: " ++ show a
                            ++ " has no inverse modulo " ++ show m)

-- |an elliptic curve y^2 = x^3 + a*x + b taken modulo p, stored as the triple (a, b, p);
-- shown as the curve equation, e.g. "y^2=x^3+1*x+2 mod 5"
data EC = EC (ECInt, ECInt, ECInt) deriving (Eq)

instance Show EC where
    show (EC (a, b, p)) = concat ["y^2=x^3+", show a, "*x+", show b, " mod ", show p]

-- |class of all Elliptic Curve point representations. Each instance supplies its own
-- point at infinity and its own group-law formulas, so generic algorithms can be
-- written once against this interface.
class ECP a where
    -- |the point at infinity in this representation, used as the start value of generic algorithms
    inf :: a
    -- |affine x resp. y coordinate of a point with respect to the given curve
    getx, gety :: a -> EC -> ECInt
    -- |add a point to itself (the special case padd a a c)
    pdouble :: a -> EC -> a
    -- |add two points on the given curve
    padd :: a -> a -> EC -> a

-- |Elliptic Point Affine coordinates: a finite point (x, y), or the point at infinity 'Infa'
-- (shown as "Null"). 'pdouble' and 'padd' return coordinates reduced modulo p.
data EPa = EPa (ECInt, ECInt) 
         | Infa
           deriving (Eq)
instance Show EPa where show (EPa (a,b)) = show (a,b)
                        show Infa = "Null"
instance ECP EPa where 
    inf = Infa
    getx (EPa (x,_)) _ = x
    getx Infa _ = error "Codec.Encryption.ECC.Base.getx: the affine point at infinity has no x coordinate"
    gety (EPa (_,y)) _ = y
    gety Infa _ = error "Codec.Encryption.ECC.Base.gety: the affine point at infinity has no y coordinate"
    -- tangent rule: lambda = (3*x1^2 + alpha) / (2*y1)
    pdouble (EPa (x1,y1)) (EC (alpha,_,p))
        -- a point with y = 0 is its own negative, so doubling it gives infinity;
        -- this also keeps modinv from being asked to invert 0
        | y1 `mod` p == 0 = Infa
        | otherwise =
            let lambda = ((3*x1^(2::Int)+alpha)*(modinv (2*y1) p)) `mod` p
                x3 = (lambda^(2::Int) - 2*x1) `mod` p
                y3 = (lambda*(x1-x3)-y1) `mod` p
            in EPa (x3,y3)
    pdouble Infa _ = Infa
    padd Infa a _ = a
    padd a Infa _ = a
    -- chord rule: lambda = (y2 - y1) / (x2 - x1)
    padd a@(EPa (x1,y1)) (EPa (x2,y2)) c@(EC (_,_,p))
        -- for points on the curve, equal x means Q = -P or Q = P, where the chord rule
        -- would divide by 0. Compare modulo p: reduced coordinates of -P are (x, p - y),
        -- so a plain y1 == -y2 test never detects P + (-P)
        | sameX, (y1 + y2) `mod` p == 0 = Infa
        | sameX = pdouble a c
        | otherwise =
            let lambda = ((y2-y1)*(modinv (x2-x1) p)) `mod` p
                x3 = (lambda^(2::Int) - x1 - x2) `mod` p
                y3 = (lambda*(x1-x3)-y1) `mod` p
            in EPa (x3,y3)
      where sameX = (x2 - x1) `mod` p == 0

-- |Elliptic Point Projective coordinates: (X, Y, Z) stands for the affine point (X/Z, Y/Z);
-- 'Infp' (shown as "Null") is the point at infinity
data EPp = EPp (ECInt, ECInt, ECInt) 
         | Infp
           deriving (Eq)
instance Show EPp where show (EPp (a,b,c)) = show (a,b,c)
                        show Infp = "Null"
instance ECP EPp where
    inf = Infp
    getx (EPp (x,_,z)) (EC (_,_,p))= (x * (modinv z p)) `mod` p
    getx Infp _ = error "Codec.Encryption.ECC.Base.getx: the projective point at infinity has no affine x coordinate"
    gety (EPp (_,y,z)) (EC (_,_,p))= (y * (modinv z p)) `mod` p
    gety Infp _ = error "Codec.Encryption.ECC.Base.gety: the projective point at infinity has no affine y coordinate"
    pdouble (EPp (x1,y1,z1)) (EC (alpha,_,p))
        -- y = 0 (a point of order 2) or z = 0 would make Z3 = 8*(y*z)^3 vanish, i.e. the
        -- result is infinity; return the canonical 'Infp' rather than an (X:Y:0) triple
        -- that the getters cannot convert and padd does not recognise
        | (y1*z1) `mod` p == 0 = Infp
        | otherwise =
            let a = (alpha*z1^(2::Int)+3*x1^(2::Int)) `mod` p
                b = (y1*z1) `mod` p
                c = (x1*y1*b) `mod` p
                d = (a^(2::Int)-8*c) `mod` p
                x3 = (2*b*d) `mod` p
                y3 = (a*(4*c-d)-8*y1^(2::Int)*b^(2::Int)) `mod` p
                z3 = (8*b^(3::Int)) `mod` p
            in EPp (x3,y3,z3)
    pdouble Infp _ = Infp
    padd Infp a _ = a
    padd a Infp _ = a
    padd p1@(EPp (x1,y1,z1)) (EPp (x2,y2,z2)) curve@(EC (_,_,p))
        -- one point has many (X:Y:Z) triples, so equality cannot be decided on the raw
        -- tuples. b == 0 means X1/Z1 == X2/Z2; then a == 0 means the points coincide,
        -- otherwise (for points on the curve) Q == -P
        | b == 0, a == 0 = pdouble p1 curve
        | b == 0 = Infp
        | otherwise = EPp (x3,y3,z3)
      where
        a = (y2*z1 - y1*z2) `mod` p
        b = (x2*z1 - x1*z2) `mod` p
        c = (a^(2::Int)*z1*z2 - b^(3::Int) - 2*b^(2::Int)*x1*z2) `mod` p
        x3 = (b*c) `mod` p
        y3 = (a*(b^(2::Int)*x1*z2-c)-b^(3::Int)*y1*z2) `mod` p
        z3 = (b^(3::Int)*z1*z2) `mod` p
    
-- |Elliptic Point Jacobian coordinates: (X, Y, Z) stands for the affine point (X/Z^2, Y/Z^3);
-- 'Infj' (shown as "Null") is the point at infinity
data EPj = EPj (ECInt, ECInt, ECInt) 
         | Infj
           deriving (Eq)
instance Show EPj where show (EPj (a,b,c)) = show (a,b,c)
                        show Infj = "Null"
instance ECP EPj where
    inf = Infj
    getx (EPj (x,_,z)) (EC (_,_,p)) = (x * (modinv (z^(2::Int)) p)) `mod` p
    getx Infj _ = error "Codec.Encryption.ECC.Base.getx: the Jacobian point at infinity has no affine x coordinate"
    gety (EPj (_,y,z)) (EC (_,_,p)) = (y * (modinv (z^(3::Int)) p)) `mod` p
    gety Infj _ = error "Codec.Encryption.ECC.Base.gety: the Jacobian point at infinity has no affine y coordinate"
    pdouble (EPj (x1,y1,z1)) (EC (alpha,_,p))
        -- y = 0 (a point of order 2) or z = 0 would make Z3 = 2*y*z vanish, i.e. the
        -- result is infinity; return the canonical 'Infj' instead of an (X:Y:0) triple
        | (y1*z1) `mod` p == 0 = Infj
        | otherwise =
            let a = 4*x1*y1^(2::Int) `mod` p
                b = (3*x1^(2::Int) + alpha*z1^(4::Int)) `mod` p
                x3 = (-2*a + b^(2::Int)) `mod` p
                y3 = (-8*y1^(4::Int) + b*(a-x3)) `mod` p
                z3 = 2*y1*z1 `mod` p
            in EPj (x3,y3,z3)
    pdouble Infj _ = Infj
    padd Infj a _ = a
    padd a Infj _ = a 
    padd p1@(EPj (x1,y1,z1)) (EPj (x2,y2,z2)) curve@(EC (_,_,p))
        -- one point has many Jacobian triples, so compare the scaled coordinates:
        -- e == 0 means equal affine x; then f == 0 means the points coincide,
        -- otherwise (for points on the curve) Q == -P
        | e == 0, f == 0 = pdouble p1 curve
        | e == 0 = Infj
        | otherwise = EPj (x3,y3,z3)
      where
        a = (x1*z2^(2::Int)) `mod` p
        b = (x2*z1^(2::Int)) `mod` p
        c = (y1*z2^(3::Int)) `mod` p
        d = (y2*z1^(3::Int)) `mod` p
        e = (b - a) `mod` p
        f = (d - c) `mod` p
        x3 = (-e^(3::Int) - 2*a*e^(2::Int) + f^(2::Int)) `mod` p
        y3 = (-c*e^(3::Int) + f*(a*e^(2::Int) - x3)) `mod` p
        z3 = (z1*z2*e) `mod` p

-- |Elliptic Point Modified Jacobian coordinates: (X, Y, Z, W) stands for the affine point
-- (X/Z^2, Y/Z^3), with W = alpha*Z^4 carried along so doubling can use it directly;
-- 'Infmj' (shown as "Null") is the point at infinity
data EPmj = EPmj (ECInt, ECInt, ECInt, ECInt) 
         | Infmj
           deriving (Eq)
instance Show EPmj where show (EPmj (a,b,c,d)) = show (a,b,c,d)
                         show Infmj = "Null"
instance ECP EPmj where
    inf = Infmj
    getx (EPmj (x,_,z,_)) (EC (_,_,p)) = (x * (modinv (z^(2::Int)) p)) `mod` p
    getx Infmj _ = error "Codec.Encryption.ECC.Base.getx: the modified Jacobian point at infinity has no affine x coordinate"
    gety (EPmj (_,y,z,_)) (EC (_,_,p)) = (y * (modinv (z^(3::Int)) p)) `mod` p
    gety Infmj _ = error "Codec.Encryption.ECC.Base.gety: the modified Jacobian point at infinity has no affine y coordinate"
    -- doubling takes alpha*Z^4 from the fourth component instead of the curve parameters
    pdouble (EPmj (x1,y1,z1,z1')) (EC (_,_,p))
        -- y = 0 (a point of order 2) or z = 0 would make Z3 = 2*y*z vanish, i.e. the
        -- result is infinity; return the canonical 'Infmj' instead of a (X:Y:0:_) tuple
        | (y1*z1) `mod` p == 0 = Infmj
        | otherwise =
            let s = 4*x1*y1^(2::Int) `mod` p
                u = 8*y1^(4::Int) `mod` p
                m = (3*x1^(2::Int) + z1') `mod` p
                t = (-2*s + m^(2::Int)) `mod` p
                x3 = t
                y3 = (m*(s - t) - u) `mod` p
                z3 = 2*y1*z1 `mod` p
                z3' = 2*u*z1' `mod` p
            in EPmj (x3,y3,z3,z3')
    pdouble Infmj _ = Infmj
    padd Infmj a _ = a
    padd a Infmj _ = a 
    padd p1@(EPmj (x1,y1,z1,_)) (EPmj (x2,y2,z2,_)) curve@(EC (alpha,_,p))
        -- one point has many representations, so compare the scaled coordinates:
        -- h == 0 means equal affine x; then r == 0 means the points coincide,
        -- otherwise (for points on the curve) Q == -P
        | h == 0, r == 0 = pdouble p1 curve
        | h == 0 = Infmj
        | otherwise = EPmj (x3,y3,z3,z3')
      where
        u1 = (x1*z2^(2::Int)) `mod` p
        u2 = (x2*z1^(2::Int)) `mod` p
        s1 = (y1*z2^(3::Int)) `mod` p
        s2 = (y2*z1^(3::Int)) `mod` p
        h = (u2 - u1) `mod` p
        r = (s2 - s1) `mod` p
        x3 = (-h^(3::Int) - 2*u1*h^(2::Int) + r^(2::Int)) `mod` p
        y3 = (-s1*h^(3::Int) + r*(u1*h^(2::Int) - x3)) `mod` p
        z3 = (z1*z2*h) `mod` p
        -- re-establish the cached alpha*Z^4 for the new Z
        z3' = (alpha*z3^(4::Int)) `mod` p

-- |generic entry point for point multiplication of @pt@ by the scalar @k@ on @curve@.
-- It currently just forwards to 'dnadd'; the algorithm behind it will likely change.
pmul :: (ECP a) => a -> ECInt -> EC -> a
pmul pt k curve = dnadd pt k curve

-- |double and add for generic ECP: computes @k' * b@ by scanning the bits of @k'@ from
-- the least significant end. At each step it doubles the running power of @b@ and adds
-- that power into the accumulator whenever the current bit is set. @k' == 0@ yields 'inf'.
--
-- The scalar is deliberately not reduced modulo @p - 1@. The group of curve points has
-- an order @n@ that in general differs from @p - 1@ (Hasse: @|n - (p+1)| <= 2*sqrt p@),
-- so @k' mod (p - 1)@ would select a different multiple of @b@.
-- A negative scalar would never reach 0 under @div 2@, and the generic 'ECP' interface
-- offers no point negation, so it is rejected with 'error'.
dnadd :: (ECP a) => a -> ECInt -> EC -> a
dnadd b k' c
    | k' < 0 = error ("Codec.Encryption.ECC.Base.dnadd: negative scalar " ++ show k')
    | otherwise = ex b k' inf
  where
    -- ex pt k acc: pt is b doubled once per consumed bit, k holds the bits still to
    -- process, acc is the sum of those pt values whose bit was set
    ex pt k acc
        | k == 0 = acc
        | odd k = ex (pdouble pt c) (k `div` 2) (padd pt acc c)
        | otherwise = ex (pdouble pt c) (k `div` 2) acc

-- |generic check whether a point lies on the curve: converts it to affine form via
-- 'getx' and 'gety' and tests y^2 == x^3 + alpha*x + beta (mod p).
-- The point at infinity has no affine coordinates, so the getters fail on it.
ison :: (ECP a) => a -> EC -> Bool
ison pt curve@(EC (alpha,beta,p)) = lhs == rhs
  where
    x   = getx pt curve
    y   = gety pt curve
    lhs = y^(2::Int) `mod` p
    rhs = (x^(3::Int) + alpha*x + beta) `mod` p