-----------------------------------------------------------------------------
-- |
-- Module: Data.Enumerator.Text
-- Copyright: 2010 John Millikin
-- License: MIT
--
-- Maintainer: jmillikin@gmail.com
-- Portability: portable
--
-- Enumerator-based text IO
--
-----------------------------------------------------------------------------
module Data.Enumerator.Text (
	  -- * Enumerators and iteratees
	  enumHandle
	, enumFile
	, iterHandle
	  -- * Codecs
	, Codec
	, encode
	, decode
	, utf8
	, utf16_le
	, utf16_be
	, utf32_le
	, utf32_be
	, ascii
	, iso8859_1
	) where
import Control.Monad.IO.Class (MonadIO)
import qualified Control.Exception as E
import qualified Data.Text as T
import qualified Data.Text.IO as T
import qualified System.IO as IO
import System.IO.Error (isEOFError)
import Control.Arrow (first)
import Data.Bits ((.&.))
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import qualified Data.Text.Encoding as TE
import Data.Bits ((.|.), shiftL)
import Data.Word (Word16)
import Prelude as Prelude
import Numeric (showIntAtBase)
import Data.Char (toUpper, intToDigit, ord)
import Data.Word (Word8)
import System.IO.Unsafe (unsafePerformIO)
import Data.Enumerator
import Data.Enumerator.Util
-- | Stream the contents of a handle to an 'Iteratee', one line of text
-- per chunk. Reading stops at end of file, or as soon as the iteratee no
-- longer asks for input. If an exception occurs during file IO,
-- enumeration will stop and 'Error' will be returned. Exceptions from the
-- iteratee are not caught.
--
-- The handle should be opened with an appropriate text encoding, and
-- in 'IO.ReadMode' or 'IO.ReadWriteMode'.
enumHandle :: MonadIO m => IO.Handle -> Enumerator T.Text m b
enumHandle h = Iteratee . feed where
	-- Keep supplying lines for as long as the iteratee wants more input.
	feed (Continue k) = nextLine $ maybe
		(return (Continue k))
		(\line -> runIteratee (k (Chunks [line])) >>= feed)
	feed finished = return finished
	-- 'Nothing' signals end of file; any other IO error is re-thrown.
	nextLine = tryStep (E.catch (fmap Just (T.hGetLine h)) onIOError)
	onIOError err
		| isEOFError err = return Nothing
		| otherwise = E.throwIO err
-- | Opens a file path in text mode for reading, and passes the handle to
-- 'enumHandle'. The handle is closed when enumeration finishes, whether
-- it completes normally or is interrupted by an exception.
enumFile :: FilePath -> Enumerator T.Text IO b
enumFile path step = Iteratee $ tryStep openForReading $ \h ->
	runIteratee (enumHandle h step) `E.finally` IO.hClose h
	where
		openForReading = IO.openFile path IO.ReadMode
-- | Consume text from a stream, writing every chunk to the handle in
-- order. Finishes with @()@ once the stream reaches 'EOF'. If an exception
-- occurs during file IO, enumeration will stop and 'Error' will be
-- returned.
--
-- The handle should be opened with an appropriate text encoding, and
-- in 'IO.WriteMode' or 'IO.ReadWriteMode'.
iterHandle :: MonadIO m => IO.Handle -> Iteratee T.Text m ()
iterHandle h = continue sink where
	sink stream = case stream of
		EOF -> yield () EOF
		-- Nothing to write; just ask for more input.
		Chunks [] -> continue sink
		Chunks texts -> Iteratee $ tryStep
			(mapM_ (T.hPutStr h) texts)
			(\_ -> return $ Continue sink)
-- | A named pair of conversions between 'T.Text' and 'B.ByteString'.
data Codec = Codec
	{ codecName :: T.Text
	-- ^ Name of the encoding; this is what the 'Show' instance displays.
	, codecEncode :: [T.Text] -> Either E.SomeException [B.ByteString]
	-- ^ Encode chunks of text, or report why they can't be encoded.
	, codecDecode :: B.ByteString -> Either E.SomeException (T.Text, B.ByteString)
	-- ^ Decode bytes into text plus any bytes left undecoded, or report
	-- why the input can't be decoded.
	}

-- | Shows only the codec's name, e.g. @Codec \"UTF-8\"@.
instance Show Codec where
	showsPrec prec codec = showParen (prec > 10) body where
		body = showString "Codec " . shows (codecName codec)
-- | Convert a stream of text into a stream of bytes, using the given
-- codec. If the codec reports that some input can't be encoded, that
-- error is thrown in the resulting iteratee.
encode :: Monad m => Codec -> Enumeratee T.Text B.ByteString m b
encode codec = go where
	go = checkDone (continue . feed)
	feed k stream = case stream of
		EOF -> yield (Continue k) EOF
		Chunks [] -> continue (feed k)
		Chunks texts -> either
			throwError
			(\byteChunks -> k (Chunks byteChunks) >>== go)
			(codecEncode codec texts)
-- | Convert a stream of bytes into a stream of text, using the given
-- codec. Bytes the codec leaves undecoded are kept and prepended to the
-- next chunk of input; any still pending at 'EOF' are yielded as leftover
-- input. If the codec reports that the input can't be decoded, that error
-- is thrown in the resulting iteratee.
decode :: Monad m => Codec -> Enumeratee B.ByteString T.Text m b
decode codec = go B.empty where
	go pending = checkDone (continue . feed pending)
	feed pending k EOF
		| B.null pending = yield (Continue k) EOF
		| otherwise = yield (Continue k) (Chunks [pending])
	feed pending k (Chunks []) = continue (feed pending k)
	feed pending k (Chunks chunks) =
		case codecDecode codec (B.concat (pending : chunks)) of
			Left err -> throwError err
			Right (text, rest)
				-- Nothing decodable yet; wait for more bytes.
				| T.null text -> continue (feed rest k)
				| otherwise -> k (Chunks [text]) >>== go rest
-- | UTF-8 codec.
--
-- Decoding consumes only complete sequences: if the input ends part-way
-- through a multi-byte sequence, the bytes of that partial sequence are
-- returned as the undecoded remainder. Invalid bytes are passed on to
-- 'TE.decodeUtf8', and its exception becomes the codec's error.
utf8 :: Codec
utf8 = Codec name enc dec where
	name = T.pack "UTF-8"
	enc = Right . Prelude.map TE.encodeUtf8
	dec = unsafeTryDec . splitBytes
	splitBytes bytes = loop 0 where
		-- Length of the sequence introduced by lead byte x0.
		required x0
			| x0 .&. 0x80 == 0x00 = 1
			| x0 .&. 0xE0 == 0xC0 = 2
			| x0 .&. 0xF0 == 0xE0 = 3
			| x0 .&. 0xF8 == 0xF0 = 4
			-- Invalid input; let Text figure it out
			| otherwise = 1
		maxN = B.length bytes
		-- Invariant: n <= maxN, and the bytes before offset n are
		-- complete sequences.
		loop n
			| n == maxN = (TE.decodeUtf8 bytes, B.empty)
			-- The sequence starting at n runs past the end of the
			-- input. The check must be relative to n; comparing req
			-- alone against maxN lets the loop step past the end.
			| n + req > maxN = first TE.decodeUtf8 (B.splitAt n bytes)
			| otherwise = loop $! n + req
			where
				req = required (B.index bytes n)
-- | UTF-16 little-endian codec.
--
-- Decoding consumes only complete code points: an odd trailing byte, or a
-- high surrogate whose low surrogate has not arrived yet, is returned as
-- the undecoded remainder. Exceptions from 'TE.decodeUtf16LE' become the
-- codec's error.
utf16_le :: Codec
utf16_le = Codec name enc dec where
	name = T.pack "UTF-16-LE"
	enc = Right . Prelude.map TE.encodeUtf16LE
	dec = unsafeTryDec . splitBytes
	splitBytes bytes = loop 0 where
		maxN = B.length bytes
		-- Invariant: n <= maxN, and the bytes before offset n are
		-- complete code points.
		loop n
			| n == maxN = (TE.decodeUtf16LE bytes, B.empty)
			-- Only half of a code unit is left.
			| n + 1 == maxN = decodeTo n
			-- A surrogate pair starts here but is not complete yet.
			| n + req > maxN = decodeTo n
			| otherwise = loop $! n + req
			where
				-- Inspect the code unit at the current offset (low
				-- byte first), not the one at the start of the input.
				req = utf16Required (B.index bytes n) (B.index bytes (n + 1))
		decodeTo n = first TE.decodeUtf16LE $ B.splitAt n bytes
-- | UTF-16 big-endian codec.
--
-- Decoding consumes only complete code points: an odd trailing byte, or a
-- high surrogate whose low surrogate has not arrived yet, is returned as
-- the undecoded remainder. Exceptions from 'TE.decodeUtf16BE' become the
-- codec's error.
utf16_be :: Codec
utf16_be = Codec name enc dec where
	name = T.pack "UTF-16-BE"
	enc = Right . Prelude.map TE.encodeUtf16BE
	dec = unsafeTryDec . splitBytes
	splitBytes bytes = loop 0 where
		maxN = B.length bytes
		-- Invariant: n <= maxN, and the bytes before offset n are
		-- complete code points.
		loop n
			| n == maxN = (TE.decodeUtf16BE bytes, B.empty)
			-- Only half of a code unit is left.
			| n + 1 == maxN = decodeTo n
			-- A surrogate pair starts here but is not complete yet.
			| n + req > maxN = decodeTo n
			| otherwise = loop $! n + req
			where
				-- Inspect the code unit at the current offset, not the
				-- one at the start of the input. 'utf16Required' takes
				-- the low byte first, which is the second byte here.
				req = utf16Required (B.index bytes (n + 1)) (B.index bytes n)
		decodeTo n = first TE.decodeUtf16BE $ B.splitAt n bytes
-- | Number of bytes taken by the UTF-16 code point that starts with the
-- given code unit. The first argument is the low byte of the unit, the
-- second its high byte. A unit in the range 0xD800-0xDBFF needs 4 bytes;
-- anything else needs 2.
utf16Required :: Word8 -> Word8 -> Int
utf16Required x0 x1
	| startsPair = 4
	| otherwise = 2
	where
		unit :: Word16
		unit = (fromIntegral x1 `shiftL` 8) .|. fromIntegral x0
		startsPair = unit >= 0xD800 && unit <= 0xDBFF
-- | UTF-32 little-endian codec. Decoding only consumes whole 4-byte
-- groups (see 'utf32SplitBytes'); exceptions from 'TE.decodeUtf32LE'
-- become the codec's error.
utf32_le :: Codec
utf32_le = Codec
	(T.pack "UTF-32-LE")
	(Right . Prelude.map TE.encodeUtf32LE)
	(unsafeTryDec . utf32SplitBytes TE.decodeUtf32LE)
	

-- | UTF-32 big-endian codec. Decoding only consumes whole 4-byte groups
-- (see 'utf32SplitBytes'); exceptions from 'TE.decodeUtf32BE' become the
-- codec's error.
utf32_be :: Codec
utf32_be = Codec
	(T.pack "UTF-32-BE")
	(Right . Prelude.map TE.encodeUtf32BE)
	(unsafeTryDec . utf32SplitBytes TE.decodeUtf32BE)
-- | Apply a decoder to the longest prefix of the input whose length is a
-- multiple of 4, and return the trailing 0-3 bytes untouched alongside
-- the decoder's result.
utf32SplitBytes :: (B.ByteString -> a) -> B.ByteString -> (a, B.ByteString)
utf32SplitBytes dec bytes = (dec whole, leftover) where
	trailing = B.length bytes `mod` 4
	(whole, leftover)
		| trailing == 0 = (bytes, B.empty)
		| otherwise = B.splitAt (B.length bytes - trailing) bytes
-- | ASCII codec. Encoding fails on the first character above U+007F, and
-- decoding fails on the first byte above 0x7F. Decoding never leaves
-- bytes over.
ascii :: Codec
ascii = Codec name (mapEither enc) dec where
	name = T.pack "ASCII"
	enc t = maybe
		(Right (B8.pack (T.unpack t)))
		(illegalEnc name)
		(T.findBy ((> 0x7F) . ord) t)
	dec bytes = maybe
		(Right (T.pack (B8.unpack bytes), B.empty))
		(illegalDec name)
		(B.find (> 0x7F) bytes)
-- | ISO-8859-1 (Latin-1) codec. Encoding fails on the first character
-- above U+00FF. Decoding always succeeds and never leaves bytes over:
-- each byte becomes the character with the same code.
iso8859_1 :: Codec
iso8859_1 = Codec name (mapEither enc) dec where
	name = T.pack "ISO-8859-1"
	enc t = maybe
		(Right (B8.pack (T.unpack t)))
		(illegalEnc name)
		(T.findBy ((> 0xFF) . ord) t)
	dec bytes = Right (T.pack (B8.unpack bytes), B.empty)
-- | Build the error a codec returns for a character it can't encode.
-- The message names the codec and shows the character as @U+XXXX@, in
-- upper-case hex zero-padded to at least four digits.
illegalEnc :: T.Text -> Char -> Either E.SomeException a
illegalEnc name c = Left (E.toException (E.ErrorCall message)) where
	digits = showIntAtBase 16 (toUpper . intToDigit) (ord c) ""
	-- 'replicate' yields nothing for a non-positive count, so longer
	-- code points are left as they are.
	padded = replicate (4 - Prelude.length digits) '0' ++ digits
	message = "Codec " ++ show name ++ " can't encode character U+" ++ padded
-- | Build the error a codec returns for a byte it can't decode. The
-- message names the codec and shows the byte as @0xXX@, in upper-case
-- hex zero-padded to two digits.
illegalDec :: T.Text -> Word8 -> Either E.SomeException a
illegalDec name w = Left (E.toException (E.ErrorCall message)) where
	digits = showIntAtBase 16 (toUpper . intToDigit) w ""
	-- 'replicate' yields nothing for a non-positive count.
	padded = replicate (2 - Prelude.length digits) '0' ++ digits
	message = "Codec " ++ show name ++ " can't decode byte 0x" ++ padded
-- | Turn an exception hidden inside the first component of a pair into a
-- 'Left'. The first component is evaluated to weak head normal form
-- inside 'E.try'; if that succeeds, the pair is returned unchanged.
--
-- The pair itself is matched before the 'E.try' runs, so an exception
-- raised while producing the pair is not caught here.
unsafeTryDec :: (a, b) -> Either E.SomeException (a, b)
unsafeTryDec (a, b) = unsafePerformIO $
	fmap (either Left (const (Right (a, b)))) (E.try (E.evaluate a))