module Data.Mnemonic.Electrum
	( encode
	, decode
	, encode'
	, decode'
	, Hex
	, encodeHex
	, decodeHex
	, encodeHex'
	, decodeHex'
	) where
import Data.Char
import Data.List
import Data.Maybe
import Text.Printf
import Control.Applicative
import Prelude
import Data.Mnemonic.Electrum.Types
import Data.Mnemonic.Electrum.WordList.Poetic
-- | Encode an 'Integer' as a mnemonic phrase using the default
-- 'poeticWords' word list. See 'encode'' to supply another list.
encode :: Integer -> String
encode n = encode' poeticWords n
-- | Encode an 'Integer' as a mnemonic phrase drawn from the given
-- 'WordList'. The integer is first split into chunks by 'chunkInteger',
-- and the chunks are then turned into words by 'encodeChunks'.
encode' :: WordList -> Integer -> String
encode' wl n = encodeChunks wl (chunkInteger n)
-- | Encode a hexadecimal string as a mnemonic phrase using the default
-- 'poeticWords' word list. See 'encodeHex'' to supply another list.
encodeHex :: Hex -> String
encodeHex h = encodeHex' poeticWords h
-- | Encode a hexadecimal string as a mnemonic phrase drawn from the
-- given 'WordList'. The string is split into chunks by 'chunkHex', and
-- the chunks are then turned into words by 'encodeChunks'.
encodeHex' :: WordList -> Hex -> String
encodeHex' wl h = encodeChunks wl (chunkHex h)
-- | Turn a list of chunks into a space-separated phrase, three words
-- per chunk, in chunk order.
--
-- For a chunk @x@ and word-list length @len@, the three word indices are
--
-- > i1 = x mod len
-- > i2 = (x div len + i1) mod len
-- > i3 = (x div len div len + i2) mod len
--
-- Words are looked up with '!!', so this assumes @len@ does not exceed
-- the actual length of the list (TODO confirm against 'WordList').
encodeChunks :: WordList -> [Int] -> String
encodeChunks (WordList len wl) = unwords . concatMap chunkWords
  where
	-- The three words for a single chunk, in output order.
	chunkWords x = map (wl !!) [i1, i2, i3]
	  where
		q = x `div` len
		i1 = x `mod` len
		i2 = (q + i1) `mod` len
		i3 = (q `div` len + i2) `mod` len
-- | Split an 'Integer' into base-'chunkSize' digits, most significant
-- digit first.
--
-- Zero yields @[0]@. A negative input yields the empty list, because
-- digits are only produced while the remaining value is positive.
chunkInteger :: Integer -> [Int]
chunkInteger 0 = [0]
chunkInteger i = reverse (unfoldr step i)
  where
	-- Peel off the least significant digit; stop once nothing is left.
	step n
		| n > 0 =
			let (rest, c) = n `divMod` chunkSize
			in Just (fromIntegral c, rest)
		| otherwise = Nothing
-- | Radix used when splitting integers into chunks: 2^32, i.e. each
-- chunk corresponds to 8 hexadecimal digits.
chunkSize :: Integer
chunkSize = 2 ^ (32 :: Int)
-- | Split a hexadecimal string into consecutive groups of 8 characters
-- and parse each group with 'readHex', preserving order.
--
-- Trailing characters that do not fill a complete group of 8 are
-- dropped.
chunkHex :: Hex -> [Int]
chunkHex = unfoldr step
  where
	-- Split first and measure only the (at most 8 character) chunk, so
	-- the whole remaining string is not re-traversed on every step.
	step s = case splitAt 8 s of
		(c, rest)
			| length c == 8 -> Just (readHex c, rest)
			| otherwise -> Nothing
-- | Decode a mnemonic phrase into an 'Integer' using the default
-- 'poeticWords' word list. Returns 'Nothing' when 'decode'' does.
decode :: String -> Maybe Integer
decode s = decode' poeticWords s
-- | Decode a mnemonic phrase into an 'Integer' using the given
-- 'WordList'.
--
-- The chunks produced by 'decodeChunks' are treated as base-'chunkSize'
-- digits, most significant first. Returns 'Nothing' whenever
-- 'decodeChunks' does.
decode' :: WordList -> String -> Maybe Integer
decode' wl s = fmap (foldl' shiftIn 0) (decodeChunks wl s)
  where
	-- Horner step: shift the running value up one digit, add the next.
	shiftIn acc n = acc * chunkSize + fromIntegral n
-- | Decode a mnemonic phrase into a hexadecimal string using the
-- default 'poeticWords' word list. Returns 'Nothing' when 'decodeHex''
-- does.
decodeHex :: String -> Maybe Hex
decodeHex s = decodeHex' poeticWords s
-- | Decode a mnemonic phrase into a hexadecimal string using the given
-- 'WordList'. Each chunk from 'decodeChunks' is rendered with 'showHex'
-- and the pieces are concatenated in order. Returns 'Nothing' whenever
-- 'decodeChunks' does.
decodeHex' :: WordList -> String -> Maybe Hex
decodeHex' wl s = fmap (concatMap showHex) (decodeChunks wl s)
-- | Convert a whitespace-separated phrase back into its chunks, one
-- chunk per group of three words.
--
-- Words are lower-cased before being looked up in the word list.
-- Returns 'Nothing' when the phrase contains no words, when any word is
-- not in the list, or when the number of words is not a multiple of
-- three.
--
-- For word indices @w1 w2 w3@ and word-list length @len@, the chunk is
--
-- > w1 + len * ((w2 - w1) mod len) + len^2 * ((w3 - w2) mod len)
--
-- which undoes the index arithmetic performed by 'encodeChunks'.
decodeChunks :: WordList -> String -> Maybe [Int]
decodeChunks (WordList len wl) s =
	case mapM lookupWord (words s) of
		Just idxs@(_:_) -> go [] idxs
		_ -> Nothing
  where
	-- Case-insensitive position of a word in the list, if present.
	lookupWord w = elemIndex (map toLower w) wl
	go coll [] = Just (reverse coll)
	go coll (w1:w2:w3:ws) =
		let x = w1 + len * calc w2 w1 + len2 * calc w3 w2
		in go (x:coll) ws
	-- One or two leftover words: not a whole chunk.
	go _ _ = Nothing
	len2 = len * len
	-- Difference of two word indices, wrapped into [0, len).
	calc a b = (a - b) `mod` len
-- | A string of hexadecimal digits.
type Hex = String

-- | Parse the leading run of hexadecimal digits (either case) of a
-- string as an 'Int'. Parsing stops at the first character that is not
-- a hexadecimal digit; if there are no leading digits the result is 0.
readHex :: Hex -> Int
readHex = foldl' step 0 . takeWhile isHexDigit
  where
	step acc c = acc * 16 + digitToInt c

-- | Render an 'Int' as lowercase hexadecimal, zero-padded on the left
-- to a width of at least 8 characters.
showHex :: Int -> Hex
showHex = printf "%08x"