module Data.Crypto (generatekey, encrypt, decrypt) where

import Prelude hiding (round)
import Data.Word (Word8)
import Data.Bits ((.&.))

-- | Build a key from an arbitrary 'Word'.
--
-- Only the low ten bits matter: the input is reduced modulo 1024 before it
-- is handed to 'bin10', so every 'Word' yields a key. For inputs already
-- below 1024 the reduction is the identity, so a single equation covers
-- both cases without scanning a list of candidates.
generatekey :: Word -> Key
generatekey w = bin10 (w `mod` 1024)

-- | Encrypt a single byte with the given key.
--
-- The byte is expanded to a block with 'bin', permuted by 'ip' and 'split'
-- into halves; the halves pass through 'round' with the first subkey, are
-- swapped by 'sw', pass through 'round' with the second subkey, and are
-- finally recombined by 'ipinverse' and packed back into a byte by 'dec'.
encrypt :: Key -> Word8 -> Word8
encrypt k w = dec (ipinverse afterSecond)
 where
    (subkeyA,subkeyB) = keyschedule k

    afterFirst,afterSecond :: (Block,Block)
    afterFirst  = round subkeyA (split (ip (bin w)))
    afterSecond = round subkeyB (sw afterFirst)

-- | Decrypt a single byte with the given key.
--
-- Runs the same pipeline as 'encrypt' ('bin', 'ip', 'split', 'round', 'sw',
-- 'round', 'ipinverse', 'dec') but applies the two subkeys in the opposite
-- order: the second subkey keys the first round and the first subkey keys
-- the last one.
decrypt :: Key -> Word8 -> Word8
decrypt k w = dec (ipinverse afterSecond)
 where
    (subkeyA,subkeyB) = keyschedule k

    afterFirst,afterSecond :: (Block,Block)
    afterFirst  = round subkeyB (split (ip (bin w)))
    afterSecond = round subkeyA (sw afterFirst)

-- | Derive the two round subkeys from a key.
--
-- The key is permuted with 'p10' and 'split' into halves. Rotating both
-- halves with 'ls1' and compressing them with 'p8' gives the first subkey;
-- rotating those already-rotated halves again with 'ls2' and applying 'p8'
-- gives the second.
keyschedule :: Key -> (Subkey,Subkey)
keyschedule k = (p8 shiftedOnce,p8 shiftedAgain)
 where
    -- Apply the same rotation to each half of a pair.
    onBoth :: Rotation -> (Block,Block) -> (Block,Block)
    onBoth rot (former,latter) = (rot former,rot latter)

    shiftedOnce,shiftedAgain :: (Block,Block)
    shiftedOnce  = onBoth ls1 (split (p10 k))
    shiftedAgain = onBoth ls2 shiftedOnce

-- | Apply one keyed round to a pair of half-blocks.
--
-- The right half is widened by 'expansion' and combined with the subkey via
-- 'xor8'. The result is 'split'; its first half goes through 'sbox0' and
-- its second through 'sbox1'. The two substitution outputs are merged by
-- 'p4' and combined with the left half via 'xor4'. That value becomes the
-- first component of the result; the right half is passed through
-- unchanged as the second component.
round :: Subkey -> (Block,Block) -> (Block,Block)
round subkey (l,r) = (p4 substituted `xor4` l,r)
 where
    keyed :: Block
    keyed = expansion r `xor8` subkey

    substituted :: (Block,Block)
    substituted = case split keyed of
        (former,latter) -> (sbox0 former,sbox1 latter)

-- | A single binary digit.
data Bit
    = Zero  -- ^ the digit 0 (see 'dig')
    | One   -- ^ the digit 1 (see 'dig')
 deriving stock (Show)
-- | A fixed-width group of bits. Each constructor carries exactly as many
-- 'Bit' fields as its name says, listed from the first bit to the last.
data Block
    = B2 Bit Bit
      -- ^ two bits
    | B4 Bit Bit Bit Bit
      -- ^ four bits
    | B5 Bit Bit Bit Bit Bit
      -- ^ five bits
    | B8 Bit Bit Bit Bit Bit Bit Bit Bit
      -- ^ eight bits
    | B10 Bit Bit Bit Bit Bit Bit Bit Bit Bit Bit
      -- ^ ten bits
 deriving stock (Show)

-- Readability aliases only: none of these introduces a distinct type, so
-- the width of a 'Block' is still checked at run time by pattern matching.

-- | A key as produced by 'generatekey' and consumed by 'keyschedule'.
type Key          = Block
-- | Key material inside 'keyschedule' and the round keys it returns.
type Subkey       = Block
-- | A function that rearranges the bits of one block ('p10', 'ip',
-- 'expansion').
type Permutation  = Block -> Block
-- | Cuts a block into two halves ('split').
type Split        = Block -> (Block,Block)
-- | A rotation of a block ('ls1', 'ls2').
type Rotation     = Block -> Block
-- | Signature of 'p4'.
type P4           = (Block,Block) -> Block
-- | Signature of 'p8'.
type P8           = (Block,Block) -> Block
-- | Signature of 'xor' on a pair of single bits.
type Xorbit       = (Bit,Bit) -> Bit
-- | Signature of the block-wise XORs 'xor4' and 'xor8'.
type Xor          = Block -> Block -> Block
-- | Signature of the substitution boxes 'sbox0' and 'sbox1'.
type Substitution = Block -> Block
-- | Signature of 'ipinverse'.
type IPinverse    = (Block,Block) -> Block
-- | Bit mask used by 'bin' and 'dec'.
type Mask         = Word8
-- | Bit mask used by 'bin10' and 'dec10'.
type Mask10       = Word
-- | Signature of 'sw'.
type Switch       = (Block,Block) -> (Block,Block)

-- | Convert a byte into a 'B8' block, most significant bit first.
bin :: Word8 -> Block
bin w =
    B8 (probe 0x80) (probe 0x40) (probe 0x20) (probe 0x10)
       (probe 0x08) (probe 0x04) (probe 0x02) (probe 0x01)
 where
    -- 'One' exactly when every bit of the mask is also set in @w@.
    probe :: Mask -> Bit
    probe m
     | w .&. m == m = One
     | otherwise    = Zero

-- | Convert a number in the range 0..1023 into a 'B10' block, most
-- significant bit first.
--
-- Calls 'error' with the message @"Out of bounds"@ for any larger value.
bin10 :: Word -> Block
bin10 w
 -- 'Word' has no negative values, so an upper-bound check is the whole
 -- range test; no list of candidates needs to be searched.
 | w <= 1023 =
    B10 (probe 0x200) (probe 0x100) (probe 0x80) (probe 0x40) (probe 0x20)
        (probe 0x10)  (probe 0x08)  (probe 0x04) (probe 0x02) (probe 0x01)
 | otherwise = error "Out of bounds"
 where
    -- 'One' exactly when every bit of the mask is also set in @w@.
    probe :: Mask10 -> Bit
    probe m
     | w .&. m == m = One
     | otherwise    = Zero

-- | Pack a 'B8' block back into a byte; the first bit of the block becomes
-- the most significant bit of the result.
--
-- Calls 'error' when given a block of any other width.
dec :: Block -> Word8
dec (B8 b1 b2' b3 b4' b5' b6 b7 b8') =
    weigh b1  0x80
  + weigh b2' 0x40
  + weigh b3  0x20
  + weigh b4' 0x10
  + weigh b5' 0x08
  + weigh b6  0x04
  + weigh b7  0x02
  + weigh b8' 0x01
 where
    -- A set bit contributes its mask, a clear bit contributes nothing.
    weigh :: Bit -> Mask -> Word8
    weigh Zero _ = 0
    weigh One m  = m
dec _ = error "Data.Crypto.dec: expected an 8-bit block (B8)"

-- | Pack a 'B10' block back into a number; the first bit of the block
-- becomes the most significant of the ten result bits.
--
-- Calls 'error' when given a block of any other width.
dec10 :: Block -> Word
dec10 (B10 b1 b2' b3 b4' b5' b6 b7 b8' b9 b10') =
    weigh b1   0x200
  + weigh b2'  0x100
  + weigh b3   0x80
  + weigh b4'  0x40
  + weigh b5'  0x20
  + weigh b6   0x10
  + weigh b7   0x08
  + weigh b8'  0x04
  + weigh b9   0x02
  + weigh b10' 0x01
 where
    -- A set bit contributes its mask, a clear bit contributes nothing.
    weigh :: Bit -> Mask10 -> Word
    weigh Zero _ = 0
    weigh One m  = m
dec10 _ = error "Data.Crypto.dec10: expected a 10-bit block (B10)"

-- | Turn an 'Int' into a 'Bit': zero becomes 'Zero', every other value
-- becomes 'One'.
bit :: Int -> Bit
bit n
 | n == 0    = Zero
 | otherwise = One
-- | Turn a 'Bit' into its numeric digit: 0 for 'Zero', 1 for 'One'.
dig :: Bit -> Int
dig b = case b of
    Zero -> 0
    One  -> 1

-- | The all-'Zero' block of width two.
b2 :: Block
b2 = B2 Zero Zero

-- | The all-'Zero' block of width four.
b4 :: Block
b4 = B4 Zero Zero Zero Zero

-- | The all-'Zero' block of width five.
b5 :: Block
b5 = B5 Zero Zero Zero Zero Zero

-- | The all-'Zero' block of width eight.
b8 :: Block
b8 = B8 Zero Zero Zero Zero Zero Zero Zero Zero

-- | The all-'Zero' block of width ten.
b10 :: Block
b10 = B10 Zero Zero Zero Zero Zero Zero Zero Zero Zero Zero

-- | Reorder the bits of a 'B10' block: the output takes input bits
-- 3 5 2 7 4 10 1 9 8 6, in that order.
--
-- Calls 'error' when given a block of any other width.
p10 :: Permutation
p10 (B10 k1 k2 k3 k4 k5 k6 k7 k8 k9 k10) =
     B10 k3 k5 k2 k7 k4 k10 k1 k9 k8 k6
p10 _ = error "Data.Crypto.p10: expected a 10-bit block (B10)"

-- | Cut a block into its first and second half.
--
-- A 'B10' yields two 'B5' halves and a 'B8' yields two 'B4' halves; bit
-- order inside each half is preserved. Calls 'error' for every other
-- width.
split :: Split
split (B10 k1 k2 k3 k4 k5 k6 k7 k8 k9 k10) =
    (B5 k1 k2 k3 k4 k5,B5 k6 k7 k8 k9 k10)
split (B8 b1 b2' b3 b4' b5' b6 b7 b8') =
    (B4 b1 b2' b3 b4',B4 b5' b6 b7 b8')
split _ =
    error "Data.Crypto.split: expected a 10-bit or 8-bit block (B10 or B8)"

-- | Exchange the two components of a pair of blocks.
sw :: Switch
sw halves = case halves of
    (former,latter) -> (latter,former)

-- | Left rotations of a 'B5' block: 'ls1' rotates by one position and
-- 'ls2' by two, with the bits that fall off the front re-entering at the
-- back.
--
-- Both call 'error' when given a block of any other width.
ls1,ls2 :: Rotation
ls1 (B5 k1 k2 k3 k4 k5) =
    B5 k2 k3 k4 k5 k1
ls1 _ = error "Data.Crypto.ls1: expected a 5-bit block (B5)"
ls2 (B5 k1 k2 k3 k4 k5) =
    B5 k3 k4 k5 k1 k2
ls2 _ = error "Data.Crypto.ls2: expected a 5-bit block (B5)"

-- | Select and reorder eight of the ten bits held in a pair of 'B5'
-- halves.
--
-- Numbering the bits of the first half 1..5 and of the second half 6..10,
-- the output takes bits 6 3 7 4 8 5 10 9, in that order; bits 1 and 2 are
-- dropped. Calls 'error' unless both components are 'B5'.
p8 :: P8
p8 (B5 _ _ k3 k4 k5,B5 k6 k7 k8 k9 k10) =
    B8 k6 k3 k7 k4 k8 k5 k10 k9
p8 _ = error "Data.Crypto.p8: expected a pair of 5-bit blocks (B5,B5)"

-- | Reorder the bits of a 'B8' block: the output takes input bits
-- 2 6 3 1 4 8 5 7, in that order.
--
-- Calls 'error' when given a block of any other width.
ip :: Permutation
ip (B8 b1 b2' b3 b4' b5' b6 b7 b8') =
    B8 b2' b6 b3 b1 b4' b8' b5' b7
ip _ = error "Data.Crypto.ip: expected an 8-bit block (B8)"

-- | Join a pair of 'B4' halves into one 'B8' block and reorder the bits.
--
-- Numbering the bits of the first half 1..4 and of the second half 5..8,
-- the output takes bits 4 1 3 5 7 2 8 6, in that order. Calls 'error'
-- unless both components are 'B4'.
ipinverse :: IPinverse
ipinverse (B4 b1 b2' b3 b4',B4 b5' b6 b7 b8') =
    B8 b4' b1 b3 b5' b7 b2' b8' b6
ipinverse _ =
    error "Data.Crypto.ipinverse: expected a pair of 4-bit blocks (B4,B4)"

-- | Exclusive-or of a pair of bits: 'One' when the two bits differ and
-- 'Zero' when they are equal.
xor :: Xorbit
xor operands = case operands of
    (Zero,other) -> other  -- XOR with Zero leaves the other bit as it is
    (One,Zero)   -> One
    (One,One)    -> Zero

-- | Bit-by-bit 'xor' of two 'B4' blocks.
--
-- Calls 'error' unless both arguments are 'B4'.
xor4 :: Xor
xor4 (B4 b1 b2' b3 b4') (B4 k1 k2 k3 k4) =
    B4 (xor (b1,k1)) (xor (b2',k2)) (xor (b3,k3)) (xor (b4',k4))
xor4 _ _ = error "Data.Crypto.xor4: expected two 4-bit blocks (B4)"

-- | Bit-by-bit 'xor' of two 'B8' blocks.
--
-- Calls 'error' unless both arguments are 'B8'.
xor8 :: Xor
xor8 (B8 b1 b2' b3 b4' b5' b6 b7 b8')
    (B8 k1 k2 k3 k4 k5 k6 k7 k8) =
     B8 (xor (b1,k1))  (xor (b2',k2)) (xor (b3,k3))  (xor (b4',k4))
        (xor (b5',k5)) (xor (b6,k6))  (xor (b7,k7))  (xor (b8',k8))
xor8 _ _ = error "Data.Crypto.xor8: expected two 8-bit blocks (B8)"

-- | The two substitution boxes applied by 'round'; each maps a 'B4' block
-- to a 'B2' block.
--
-- For an input @B4 a b c d@ the outer bits @a d@ select a table row and the
-- inner bits @b c@ select a column within it. The entry found there (a
-- value 0..3) is returned as the 'B2' holding its two binary digits. Both
-- functions call 'error' when given a block of any other width.
sbox0,sbox1 :: Substitution

sbox0 = sboxLookup "sbox0"
    ( (v1,v0,v3,v2)
    , (v3,v2,v1,v0)
    , (v0,v2,v1,v3)
    , (v3,v1,v3,v2)
    )
 where
    (v0,v1,v2,v3) = sboxOutputs

sbox1 = sboxLookup "sbox1"
    ( (v0,v1,v2,v3)
    , (v2,v0,v1,v3)
    , (v3,v0,v1,v0)
    , (v2,v1,v0,v3)
    )
 where
    (v0,v1,v2,v3) = sboxOutputs

-- | One row of a substitution table: the outputs for columns 0..3.
type SboxRow = (Block,Block,Block,Block)

-- | A whole substitution table: rows 0..3.
type SboxTable = (SboxRow,SboxRow,SboxRow,SboxRow)

-- | The four possible outputs of a substitution box, i.e. the values
-- 0, 1, 2 and 3 written as 'B2' blocks.
sboxOutputs :: SboxRow
sboxOutputs = (B2 Zero Zero,B2 Zero One,B2 One Zero,B2 One One)

-- | Look a 'B4' block up in a table: the first and last bit pick the row,
-- the two middle bits pick the column. The 'String' is the caller's name
-- and only appears in the error raised for blocks of the wrong width.
sboxLookup :: String -> SboxTable -> Substitution
sboxLookup _ table (B4 outerHi innerHi innerLo outerLo) =
    pick4 innerHi innerLo (pick4 outerHi outerLo table)
sboxLookup name _ _ =
    error ("Data.Crypto." ++ name ++ ": expected a 4-bit block (B4)")

-- | Select one component of a 4-tuple using two bits as a binary index
-- (first bit is the more significant one).
pick4 :: Bit -> Bit -> (a,a,a,a) -> a
pick4 Zero Zero (x,_,_,_) = x
pick4 Zero One  (_,x,_,_) = x
pick4 One  Zero (_,_,x,_) = x
pick4 One  One  (_,_,_,x) = x

-- | Widen a 'B4' block to a 'B8' block by repeating its bits: the output
-- takes input bits 4 1 2 3 2 3 4 1, in that order.
--
-- Calls 'error' when given a block of any other width.
expansion :: Permutation
expansion (B4 b1 b2' b3 b4') =
    B8 b4' b1 b2' b3 b2' b3 b4' b1
expansion _ = error "Data.Crypto.expansion: expected a 4-bit block (B4)"

-- | Join a pair of 'B2' blocks into one 'B4' block and reorder the bits.
--
-- Numbering the bits of the first block 1..2 and of the second 3..4, the
-- output takes bits 2 4 3 1, in that order. Calls 'error' unless both
-- components are 'B2'.
p4 :: P4
p4 (B2 b1 b2',B2 b3 b4') =
    B4 b2' b4' b3 b1
p4 _ = error "Data.Crypto.p4: expected a pair of 2-bit blocks (B2,B2)"








