module Crypto.PBKDF2 (pbkdf2, pbkdf2', Password(..), Salt(..), HashedPass(..),toOctets,fromOctets ) where
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy as L
import Data.Data (Data)
import Data.Typeable (Typeable)
import GHC.Word
import Control.Monad (foldM)
import System.Random
import Data.Digest.SHA512 (hash)
import Data.Word 
import Data.Bits
import Data.Binary
-- | Raw password octets fed to the key-derivation function.
newtype Password = Password [Word8]
  deriving (Eq, Ord, Show, Read, Data, Typeable)

-- | Raw salt octets fed to the key-derivation function.
newtype Salt = Salt [Word8]
  deriving (Eq, Ord, Show, Read, Data, Typeable)

-- | Octets produced by the key-derivation function.
newtype HashedPass = HashedPass [Word8]
  deriving (Eq, Ord, Show, Read, Data, Typeable)
-- | Sample derivation for a fixed password/salt pair. It is not exported,
-- so presumably it is a leftover for interactive (GHCi) checks.
t :: HashedPass
t = pbkdf2 password salt
  where
    password = Password (toOctets "blee")
    salt = Salt (toOctets "blah")
 
-- | Derive a 64-octet key from a password and salt.
--
-- Delegates to 'pbkdf2'' with 'prfSHA512' as the pseudorandom function
-- (declared output length 64 octets), 5000 iterations and a 64-octet
-- derived-key length.
--
-- NOTE(review): 5000 iterations is low by current password-hashing
-- guidance. Raising it changes every derived value, so coordinate with any
-- stored hashes before changing it.
pbkdf2 :: Password -> Salt -> HashedPass
pbkdf2 password salt = pbkdf2' (prfSHA512, hashLen) iterations keyLen password salt
  where
    hashLen = 64
    iterations = 5000
    keyLen = 64
-- | Generic PBKDF2 key derivation (RFC 8018, section 5.2).
--
-- @pbkdf2' (prf, hlen) cIters dklen password salt@ returns the first
-- @dklen@ octets of @T_1 ++ T_2 ++ ...@, where
--
-- > U_1 = prf password (salt ++ INT(i))
-- > U_j = prf password U_(j-1)
-- > T_i = U_1 `xor` U_2 `xor` ... `xor` U_cIters
--
-- and @INT(i)@ is the four-octet big-endian block index.
--
-- * @prf key msg@ is the pseudorandom function. The password is always
--   passed as its first argument.
-- * @hlen@ is the output length of @prf@ in octets. It is assumed, not
--   checked, that @prf@ really returns @hlen@ octets.
-- * @cIters@ is the iteration count and must be at least 1.
-- * @dklen@ is the requested key length in octets.
--
-- Calls 'error' when @hlen@ is not positive, when @cIters < 1@, or when
-- @dklen > (2^32 - 1) * hlen@ (the RFC's "derived key too long" case).
--
-- NOTE: the salt is concatenated with @INT(i)@ as the RFC requires. Values
-- derived by code that OR-ed the salt into @INT(i)@ will not match.
pbkdf2' :: ([Word8] -> [Word8] -> [Word8], Integer) -> Integer -> Integer -> Password -> Salt -> HashedPass
pbkdf2' (prf, hlen) cIters dklen (Password pass) (Salt salt)
  | hlen <= 0 = error $ "pbkdf2, hlen must be positive: " ++ show hlen
  | cIters < 1 = error $ "pbkdf2, cIters must be at least 1: " ++ show cIters
  | dklen > (2 ^ (32 :: Int) - 1) * hlen =
      error $ "pbkdf2, (dklen,hlen) : " ++ show (dklen, hlen)
  | otherwise =
      HashedPass . take (fromIntegral dklen) $ concatMap block [1 .. blockCount]
  where
    -- Number of hlen-octet blocks needed. Exact integer ceiling division
    -- avoids the precision loss of going through a floating-point ceiling.
    blockCount = (dklen + hlen - 1) `div` hlen

    -- U_1 .. U_cIters for block i. `iterate` yields the seed first, so it is
    -- dropped. The seed is salt || INT(i): concatenation, not bitwise OR.
    us i = take (fromIntegral cIters) . drop 1 $
      iterate (prf pass) (salt ++ int32BE i)

    -- T_i: XOR of all U_j for block i. Non-empty because cIters >= 1.
    block i = foldr1 myxor (us i)

    -- INT(i): four-octet big-endian encoding of the block index. The guard
    -- above keeps i within 32 bits.
    int32BE :: Integer -> [Word8]
    int32BE n = [fromIntegral (n `shiftR` s) | s <- [24, 16, 8, 0]]
-- | Serialise a value with its 'Binary' instance and return the encoded
-- octets as a list.
toOctets :: (Binary a) => a -> [Word8]
toOctets = L.unpack . encode
-- | Decode a value from octets using its 'Binary' instance.
--
-- This is the inverse of encoding with the same instance. 'decode' fails
-- with an exception on malformed or truncated input.
fromOctets :: (Binary a) => [Word8] -> a
fromOctets octets = decode (L.pack octets)
-- | Four-octet big-endian encoding of an integral value, e.g. the INT(i)
-- block-index encoding used by PBKDF2.
--
-- The value is taken modulo 2^32, so negative inputs yield their 32-bit
-- two's-complement octets. The result always has exactly four elements,
-- whatever the width of the input type.
intToFourWord8s :: Integral a => a -> [Word8]
intToFourWord8s i = [fromIntegral (n `shiftR` s) | s <- [24, 16, 8, 0]]
  where
    -- Widen to Integer so the shifts are well-defined for every input type.
    n = toInteger i
-- | Element-wise XOR of two octet lists. The result is as long as the
-- shorter input.
myxor :: [Word8] -> [Word8] -> [Word8]
myxor xs ys = [x `xor` y | (x, y) <- zip xs ys]
-- | Element-wise bitwise OR of two octet lists. The result is as long as
-- the shorter input.
myor :: [Word8] -> [Word8] -> [Word8]
myor xs ys = [x .|. y | (x, y) <- zip xs ys]
-- | HMAC-SHA-512 (RFC 2104), used as the PBKDF2 pseudorandom function.
--
-- @prfSHA512 key msg@ computes
-- @hash ((K xor opad) ++ hash ((K xor ipad) ++ msg))@, where @K@ is @key@
-- zero-padded to the 128-octet SHA-512 block size (keys longer than a block
-- are hashed first). 'pbkdf2'' passes the password as @key@.
--
-- A bare @hash (key ++ msg)@ would be a secret-prefix construction that is
-- open to length extension and is not the HMAC PRF used by standard PBKDF2.
-- Switching to HMAC changes every value derived through this function.
prfSHA512 :: [Word8] -> [Word8] -> [Word8]
prfSHA512 key = \msg -> hash (outerPad ++ hash (innerPad ++ msg))
  where
    -- Returning a lambda lets a partial application such as `prfSHA512 pass`
    -- share the padded keys below across calls instead of rebuilding them.
    blockSize = 128
    shortKey
      | length key > blockSize = hash key
      | otherwise = key
    blockKey = take blockSize (shortKey ++ repeat 0)
    innerPad = map (xor 0x36) blockKey
    outerPad = map (xor 0x5c) blockKey
-- | Sample application of 'prfSHA512' to fixed inputs. It is not exported,
-- so presumably it is a leftover for interactive (GHCi) checks.
t2 :: [Word8]
t2 = prfSHA512 key message
  where
    key = toOctets "asdf"
    message = toOctets "jkl; asdfjl; asjdfnkl;ajsdfl;jk;sn"