module Codec.Image.PBM
  ( PBM(..)
  
  , encodePBM
  , encodePlainPBM
  , EncodeError(..)
  , encodePBM'
  
  , DecodeError(..)
  , decodePBM
  , decodePlainPBM
  , decodePBMs
  
  , padPBM
  , trimPBM
  , repadPBM
  ) where
import Data.Bits (shiftL, shiftR, (.&.))
import Data.Ix (range)
import Data.Word (Word8)
import qualified Data.Array.Unboxed as U
import qualified Data.ByteString as BS
import qualified Compat as BSC
import Data.Bits.Bitwise (fromListBE, toListLE)
import Data.Array.BitArray (BitArray, bounds, elems, listArray, false, (//), assocs, ixmap)
import Data.Array.BitArray.ByteString (toByteString, fromByteString)
-- | A bitmap image.
data PBM = PBM
  { pbmWidth  :: !Int
    -- ^ Image width in pixels.  This may be smaller than the width of
    -- 'pbmPixels' when rows carry padding columns (see 'padPBM' and
    -- 'trimPBM').
  , pbmPixels :: !(BitArray (Int, Int))
    -- ^ Pixel bits indexed by (row, column).
  }
-- | Encode a bit array as a binary (P4) PBM image.
--
-- The array is indexed by (row, column), and its column count becomes the
-- width written in the header.  Before serialisation the rows are widened
-- to a whole number of bytes with 'padPBM'.  The result is then produced
-- by 'encodePBM''.
encodePBM :: BitArray (Int, Int)  -> BS.ByteString
encodePBM pixels = case encodePBM' pbm of
  Right string -> string
  -- 'padPBM' always yields a byte-aligned array whose padding is shorter
  -- than one byte, so 'encodePBM'' is not expected to reject it.
  _ -> error "Codec.Image.PBM.encodePBM: internal error"
  where
    ((_, xlo), (_, xhi)) = bounds pixels
    -- Number of columns in the caller's array (inclusive bounds).
    width  = xhi - xlo + 1
    pbm = padPBM PBM{ pbmWidth = width, pbmPixels = pixels }
-- | Reasons 'encodePBM'' can reject a 'PBM'.  Every constructor carries the
-- rejected image in its 'encErrPBM' field.
data EncodeError
  = BadPixelWidth { encErrPBM :: PBM }
    -- ^ The pixel array's width is not a multiple of 8.
  | BadSmallWidth { encErrPBM :: PBM }
    -- ^ 'pbmWidth' is at most the pixel array's width minus 8, i.e. there
    -- is at least a whole byte of padding per row.
  | BadLargeWidth { encErrPBM :: PBM }
    -- ^ 'pbmWidth' is larger than the pixel array's width.
-- | Encode a bit array as a plain (ASCII, P1) PBM image.
--
-- The array is indexed by (row, column).  'True' pixels are written as
-- @1@ and 'False' pixels as @0@.  Each image row is split into output lines
-- of at most 64 characters.  An array with no columns or no rows produces
-- just the header.
encodePlainPBM :: BitArray (Int, Int)  -> String
encodePlainPBM pixels = unlines (header : raster)
  where
    ((ylo, xlo), (yhi, xhi)) = bounds pixels
    width  = xhi - xlo + 1
    height = yhi - ylo + 1
    header = "P1\n" ++ show width ++ " " ++ show height
    -- 'elems' lists pixels in 'Ix' order (column varying fastest), so
    -- chunking by 'width' recovers the rows; each row is then wrapped.
    raster = concatMap (chunk 64) . chunk width . map char . elems $ pixels
    char False = '0'
    char True  = '1'
    -- The empty case is matched first: a zero-width (hence empty) array
    -- used to hit the 'n <= 0' error.  Now that error can only fire for a
    -- non-empty list with a non-positive chunk size, which cannot happen here.
    chunk _ [] = []
    chunk n _ | n <= 0 = error "Codec.Image.PBM.encodePlainPBM: internal error"
    chunk n xs = let (ys, zs) = splitAt n xs in ys : chunk n zs
-- | Encode a 'PBM' as a binary (P4) PBM image, validating its padding.
--
-- The pixel array must be a whole number of bytes wide (a multiple of 8
-- columns), and 'pbmWidth' must fall within the last byte of each row:
-- @pixelWidth - 8 < pbmWidth <= pixelWidth@.  If not, the corresponding
-- 'EncodeError' is returned.  The header records 'pbmWidth' and the
-- array's row count.
encodePBM' :: PBM -> Either EncodeError BS.ByteString
encodePBM' pbm
  | (pixelWidth .&. 7) /= 0 = Left (BadPixelWidth pbm)
  | width <= pixelWidth - 8 = Left (BadSmallWidth pbm)
  | width >  pixelWidth     = Left (BadLargeWidth pbm)
  | otherwise = Right (header `BS.append` raster)
  where
    width = pbmWidth pbm
    pixels = pbmPixels pbm
    ((ylo, xlo), (yhi, xhi)) = bounds pixels
    pixelWidth  = xhi - xlo + 1
    pixelHeight = yhi - ylo + 1
    height = pixelHeight
    -- Header characters are ASCII, so each maps directly to one byte.
    header = BS.pack $ map (toEnum . fromEnum) headerStr
    headerStr = "P4\n" ++ show width ++ " " ++ show height ++ "\n"
    -- Each byte from 'toByteString' is bit-reversed.  Presumably it packs
    -- the first pixel into the least significant bit, while P4 wants the
    -- leftmost pixel in the most significant bit (TODO confirm against
    -- Data.Array.BitArray.ByteString).
    raster = reverseByteBits (toByteString pixels)
-- | Reasons the PBM decoders can fail.  The payload is the input remaining
-- at the start of the part that could not be parsed.
data DecodeError a
  = BadMagicP a  -- ^ The input does not start with @P@.
  | BadMagicN a  -- ^ @P@ is not followed by the expected format digit.
  | BadWidth  a  -- ^ The width is missing or not a positive number.
  | BadHeight a  -- ^ The height is missing or not a positive number.
  | BadSpace  a  -- ^ No whitespace byte separates the header from the raster.
  | BadPixels a  -- ^ The raster is truncated or malformed.
  deriving (Show, Read, Ord, Eq)
-- | Decode one binary (P4) PBM image from the front of a byte string.
--
-- On success, returns the image together with the input that follows its
-- raster, so concatenated images can be read one after another (see
-- 'decodePBMs').
--
-- Header format:
--
-- * @P4@;
-- * the width and then the height, as positive decimal numbers without
--   leading zeros, each optionally preceded by whitespace and @#@ comments
--   that run to the end of the line;
-- * exactly one whitespace byte before the raster.
--
-- Every raster row occupies @(width + 7) `div` 8@ whole bytes.  The returned
-- pixel array is therefore a multiple of 8 columns wide and may contain
-- padding columns beyond 'pbmWidth', which 'trimPBM' can remove.
--
-- On failure, the 'DecodeError' carries the input remaining at the start of
-- the part that could not be parsed.
decodePBM :: BS.ByteString -> Either (DecodeError BS.ByteString) (PBM, BS.ByteString)
decodePBM s =                    case BSC.uncons s of
  Just (cP, s) | cP == char 'P' -> case BSC.uncons s of
    Just (c4, s) | c4 == char '4' -> case int (skipSpaceComment s) of
      Just (iw, s) | iw > 0         -> case int (skipSpaceComment s) of
        Just (ih, s) | ih > 0         -> case skipSingleSpace s of
          Just s                        ->
            -- Rows are padded to whole bytes in the file.
            let rowBytes = (iw + 7) `shiftR` 3
                imgBytes = ih * rowBytes
            in                             case BS.splitAt imgBytes s of
            (raster, s) | BS.length raster == imgBytes ->
              -- The pixel array keeps the padding: rowBytes * 8 columns.
              let ibs = ((0, 0), (ih - 1, (rowBytes `shiftL` 3) - 1))
              in  Right (PBM{ pbmWidth = iw, pbmPixels = fromByteString ibs (reverseByteBits raster) }, s)
            _ -> Left (BadPixels s)
          _ -> Left (BadSpace s)
        _ -> Left (BadHeight s)
      _ -> Left (BadWidth s)
    _ -> Left (BadMagicN s)
  _ -> Left (BadMagicP s)
  where
    -- Skip whitespace and any number of '#' comments.  A comment that is not
    -- terminated by '\n' yields Left.
    skipSpaceComment t = case (\t -> (t, BSC.uncons t)) (BS.dropWhile isSpace t) of
      (_, Just (cH, t)) | cH == char '#' -> case BSC.uncons (BS.dropWhile (/= char '\n') t) of
        Just (cL, t) | cL == char '\n' -> skipSpaceComment t
        _ -> Left (BadSpace t)
      (t, _) -> Right t
    -- Consume exactly one whitespace byte.
    skipSingleSpace t = case BSC.uncons t of
      Just (cS, t) | isSpace cS -> Just t
      _ -> Nothing
    -- Parse a non-empty run of decimal digits that does not start with '0'.
    -- A Left from 'skipSpaceComment' is collapsed to Nothing here.
    int (Left _) = Nothing
    int (Right t) = case BS.span isDigit t of
      (d, t)
        | BS.length d > 0 &&
          fmap ((/= char '0') . fst) (BSC.uncons d) == Just True -> case reads (map unchar $ BS.unpack d) of
            [(d, "")] -> Just (d, t)
            _ -> Nothing
      _ -> Nothing
    isSpace c = c `elem` map char pbmSpace
    isDigit c = c `elem` map char "0123456789"
    -- Convert between Char and byte values via their code points.
    char = toEnum . fromEnum
    unchar = toEnum . fromEnum
-- | Decode a stream of concatenated binary (P4) PBM images.
--
-- Images are read with 'decodePBM' until the input is used up.  The result
-- pairs every image decoded before the first failure with that failure
-- ('Nothing' when the whole input was consumed).  Any bytes after the last
-- image, including whitespace, are reported as an error.
decodePBMs :: BS.ByteString -> ([PBM], Maybe (DecodeError BS.ByteString))
decodePBMs input
  | BS.null input = ([], Nothing)
  | otherwise =
      case decodePBM input of
        Left err -> ([], Just err)
        Right (image, rest) ->
          case decodePBMs rest of
            (images, laterErr) -> (image : images, laterErr)
-- | Decode one plain (ASCII, P1) PBM image from the front of a string.
--
-- On success, returns the image together with the input that follows its
-- last pixel.
--
-- Header format: @P1@, then the width and the height as positive decimal
-- numbers without leading zeros.  Each number may be preceded by whitespace
-- and @#@ comments running to the end of the line.
--
-- Raster format: exactly @width * height@ pixel digits (@0@ or @1@), each
-- optionally preceded by whitespace.  The result has exactly 'pbmWidth'
-- columns and no padding.
--
-- On failure, the 'DecodeError' carries the input remaining at the start of
-- the part that could not be parsed.
decodePlainPBM :: String -> Either (DecodeError String) (PBM, String)
decodePlainPBM s = case s of
  ('P':s) -> case s of
    ('1':s) -> case int (skipSpaceComment s) of
      Just (iw, s) | iw > 0 -> case int (skipSpaceComment s) of
        Just (ih, s) | ih > 0 -> case collapseRaster (iw * ih) s of
          Just (raster, s) ->
            let ibs = ((0, 0), (ih - 1, iw - 1))
            in  Right (PBM{ pbmWidth = iw, pbmPixels = listArray ibs raster }, s)
          _ -> Left (BadPixels s)
        _ -> Left (BadHeight s)
      _ -> Left (BadWidth s)
    _ -> Left (BadMagicN s)
  _ -> Left (BadMagicP s)
  where
    -- Skip whitespace and any number of '#' comments.  A comment that is not
    -- terminated by '\n' yields Left.
    skipSpaceComment t = case dropWhile isSpace t of
      ('#':t) -> case dropWhile (/= '\n') t of
        ('\n':t) -> skipSpaceComment t
        _ -> Left (BadSpace t)
      t -> Right t
    -- Parse a non-empty run of decimal digits that does not start with '0'.
    -- A Left from 'skipSpaceComment' is collapsed to Nothing here.
    int (Left _) = Nothing
    int (Right t) = case span isDigit t of
      (d@(d0:_), t) | d0 /= '0' -> case reads d of
        [(d, "")] -> Just (d, t)
        _ -> Nothing
      _ -> Nothing
    -- Read exactly n pixel digits, skipping whitespace before each one.
    collapseRaster 0 t = Just ([], t)
    collapseRaster n t = case dropWhile isSpace t of
      ('0':t) -> prepend False (collapseRaster (n - 1) t)
      ('1':t) -> prepend True  (collapseRaster (n - 1) t)
      _ -> Nothing
    prepend _ Nothing = Nothing
    prepend b (Just (bs, t)) = Just (b:bs, t)
    isSpace c = c `elem` pbmSpace
    isDigit c = c `elem` "0123456789"
-- | Widen the pixel array to a whole number of bytes (a multiple of 8
-- columns).
--
-- New columns are added on the right and come from 'false', so they are
-- unset.  'pbmWidth' is left unchanged.  An array whose width is already a
-- multiple of 8 is returned as is.
padPBM :: PBM -> PBM
padPBM pbm
  | (pixelWidth .&. 7) == 0 = pbm
  | otherwise = pbm{ pbmPixels = false paddedBounds // assocs (pbmPixels pbm) }
  where
    ((ylo, xlo), (yhi, xhi)) = bounds (pbmPixels pbm)
    pixelWidth = xhi - xlo + 1
    -- Round the width up to the next multiple of 8.
    rowBytes = (pixelWidth + 7) `shiftR` 3
    paddedWidth = rowBytes `shiftL` 3
    paddedBounds = ((ylo, xlo), (yhi, xhi'))
    xhi' = paddedWidth + xlo - 1
-- | Drop the padding columns to the right of 'pbmWidth'.
--
-- Returns 'Nothing' when 'pbmWidth' exceeds the pixel array's width.  If
-- the array is already exactly 'pbmWidth' columns wide, the image is
-- returned unchanged.
trimPBM :: PBM -> Maybe PBM
trimPBM pbm
  | pbmWidth pbm > pixelWidth = Nothing
  | pbmWidth pbm == pixelWidth = Just pbm
  | otherwise = Just pbm{ pbmPixels = ixmap trimmedBounds id (pbmPixels pbm) }
  where
    ((ylo, xlo), (yhi, xhi)) = bounds (pbmPixels pbm)
    pixelWidth = xhi - xlo + 1
    -- Keep the first pbmWidth columns, starting from the original lower bound.
    trimmedBounds = ((ylo, xlo), (yhi, xhi'))
    xhi' = pbmWidth pbm + xlo - 1
-- | Normalise a 'PBM''s padding.  The image is first cut back to exactly
-- 'pbmWidth' columns with 'trimPBM', then widened to a whole number of
-- bytes with 'padPBM'.  Returns 'Nothing' exactly when 'trimPBM' does.
repadPBM :: PBM -> Maybe PBM
repadPBM = fmap padPBM . trimPBM
-- | Reverse the bit order inside every byte of a byte string.  The byte
-- order itself is unchanged.
reverseByteBits :: BS.ByteString -> BS.ByteString
reverseByteBits bytes = BS.map reverseBits bytes
-- | Reverse the bit order of one byte by looking it up in the precomputed
-- 'bitReversed' table.
reverseBits :: Word8 -> Word8
reverseBits = (bitReversed U.!)
-- | Lookup table mapping every byte value to its bit-reversed counterpart.
bitReversed :: U.UArray Word8 Word8
bitReversed = U.listArray byteRange (map bitReverse (range byteRange))
  where
    byteRange :: (Word8, Word8)
    byteRange = (minBound, maxBound)
-- | Reverse the bit order of a byte.  The byte's bits are listed least
-- significant first, and that list is then read back most significant
-- first.
bitReverse :: Word8 -> Word8
bitReverse w = fromListBE (toListLE w)
-- | Characters the PBM decoders treat as whitespace: space, tab, newline,
-- vertical tab, form feed and carriage return.
pbmSpace :: String
pbmSpace = [' ', '\t', '\n', '\v', '\f', '\r']