module Bio.Alignment.Soap ( SoapAlign(..), SoapAlignMismatch(..)
                          , refSeqPos, refCSeqLoc, refSeqLoc, mismatchSeqPos
                          , parse, unparse, parseMismatch, unparseMismatch
                          , group
                          )
    where 

import Prelude hiding (length)
import Control.Monad.Error
import qualified Data.ByteString.Lazy as LBSW
import qualified Data.ByteString.Lazy.Char8 as LBS
import Data.Char
import qualified Data.List as List (length)
import Data.List (groupBy)

import qualified Bio.Location.ContigLocation as CLoc
import qualified Bio.Location.Location as Loc
import qualified Bio.Location.Position as Pos
import Bio.Location.OnSeq
import qualified Bio.Location.SeqLocation as SeqLoc
import Bio.Location.Strand
import Bio.Sequence.SeqData


-- | Alignment output from SOAP: one tab-separated alignment record, with
-- the fields declared in the same order that 'parse' reads them and
-- 'unparse' writes them.
data SoapAlign = SA { name :: !SeqName     -- ^ Read name; the key that 'group' compares
                    , sequ :: !SeqData     -- ^ Reference strand orientation sequence
                    , qual :: !QualData    -- ^ Reference strand orientation quality data
                    , nhit :: !Int         -- ^ Hit count field (presumably the number of hits reported for the read; confirm against the SOAP documentation)
                    , pairend :: !Char     -- ^ Single-character pair-end field; its meaning is not established in this module
                    , length :: !Offset    -- ^ Alignment length; used as the contig length by 'refCSeqLoc' and 'refSeqLoc'
                    , strand :: !Strand    -- ^ Alignment strand, read from a @+@ ('Fwd') or @-@ ('RevCompl') field
                    , refname :: !SeqName  -- ^ Reference sequence name; names the sequence in the locations built from this alignment
                    , refstart :: !Offset  -- ^ 1-based index, as output by SOAP, of reference strand 5' end
                    , nmismatch :: !Int    -- ^ Mismatch count as stated in the record; 'parse' requires it to equal the number of 'mismatches'
                    , mismatches :: ![SoapAlignMismatch] -- ^ One entry per trailing mismatch field of the record
                    } deriving (Read, Show, Eq, Ord)

-- | A single mismatch between a read and the reference, as carried in the
-- trailing fields of a SOAP alignment record.  'parseMismatch' reads the
-- textual form: reference nt, @->@, decimal offset, read nt, decimal quality.
data SoapAlignMismatch = SAM { readnt :: !Char   -- ^ Read nt in reference strand orientation
                             , refnt :: !Char    -- ^ Reference nt in reference strand orientation
                             , offset :: !Offset -- ^ Offset from reference strand 5' end in reference strand orientation
                             , qualnt :: !Qual   -- ^ Quality score of read nt
                             } deriving (Read, Show, Eq, Ord)

-- | Position of one mismatch on the reference sequence of its alignment.
--
-- The position value is the alignment's 1-based 'refstart' plus the
-- mismatch 'offset', minus one; the strand is that of the alignment.
mismatchSeqPos :: SoapAlign -> SoapAlignMismatch -> SeqLoc.SeqPos
mismatchSeqPos sa sam = OnSeq (refname sa) mismatchPos
    where mismatchPos = Pos.Pos zeroBased (strand sa)
          zeroBased = refstart sa + offset sam - 1

-- | The 'CLoc.startPos' of the alignment's contiguous reference location,
-- tagged with the reference sequence name.
refSeqPos :: SoapAlign -> SeqLoc.SeqPos
refSeqPos sa = OnSeq (refname sa) startOfAlignment
    where startOfAlignment = CLoc.startPos (refCLoc sa)

-- | The alignment's contiguous location on the reference, tagged with the
-- reference sequence name.
refCSeqLoc :: SoapAlign -> SeqLoc.ContigSeqLoc
refCSeqLoc sa = OnSeq (refname sa) $ refCLoc sa

-- | Contiguous location covered by the alignment, without the sequence name.
--
-- SOAP reports 'refstart' as a 1-based index, so one is subtracted before
-- it is handed to 'CLoc.ContigLoc' together with the alignment length and
-- strand.
refCLoc :: SoapAlign -> CLoc.ContigLoc
refCLoc sa = CLoc.ContigLoc zeroBasedStart (length sa) (strand sa)
    where zeroBasedStart = refstart sa - 1

-- | The alignment's reference location as a general (possibly multi-part)
-- location holding exactly one contiguous segment, tagged with the
-- reference sequence name.
refSeqLoc :: SoapAlign -> SeqLoc.SeqLoc
refSeqLoc sa = OnSeq (refname sa) singleSegment
    where singleSegment = Loc.Loc [ refCLoc sa ]

-- | Offset between the quality bytes in the SOAP text and the values kept
-- in 'qual': 'parse' subtracts it from every byte of the quality string and
-- 'unparse' adds it back.
qualScale :: Qual
qualScale = 64

-- | Parse one tab-separated SOAP alignment record.
--
-- The first ten fields are, in order: read name, sequence, quality string,
-- hit count, pair-end character, length, strand (@+@ or @-@), reference
-- name, reference start and mismatch count.  Every remaining field is
-- handed to 'parseMismatch'.  Each byte of the quality string has
-- 'qualScale' subtracted before it is stored.
--
-- An error is raised with 'throwError' when there are fewer than ten
-- fields, when a numeric, character, strand or mismatch field is malformed,
-- or when the stated mismatch count differs from the number of mismatch
-- fields present.
parse :: (Error e, MonadError e m) => LBS.ByteString -> m SoapAlign
parse bstr = fromFields $ LBS.split '\t' bstr
    where fromFields (nameF:sequF:qualF:nhitF:pairendF:lengthF:strandF:refnameF:refstartF:nmismatchF:mismatchFs)
              = do nhitVal <- parseIntBStr nhitF
                   pairendVal <- parseCharBStr pairendF
                   lengthVal <- parseOffsetBStr lengthF
                   strandCh <- parseCharBStr strandF
                   strandVal <- strandFromChar strandCh
                   refstartVal <- parseOffsetBStr refstartF
                   nmismatchVal <- parseIntBStr nmismatchF
                   mismatchVals <- mapM parseMismatch mismatchFs
                   -- The stated count must agree with the fields actually present.
                   let nfound = List.length mismatchVals
                   unless (nfound == nmismatchVal) $ throwError $ strMsg $
                       "Bio.Alignment.Soap.parse: wrong number of errors: " ++ show (nmismatchVal, nfound)
                   return $ SA { name = nameF
                               , sequ = sequF
                               , qual = LBSW.map (\q -> q - qualScale) qualF
                               , nhit = nhitVal
                               , pairend = pairendVal
                               , length = lengthVal
                               , strand = strandVal
                               , refname = refnameF
                               , refstart = refstartVal
                               , nmismatch = nmismatchVal
                               , mismatches = mismatchVals
                               }
          fromFields fs = throwError $ strMsg $ "Bio.Alignment.Soap.parse: too few fields in soapAlign: " ++ show (List.length fs)
          strandFromChar ch = case ch of
                                '+' -> return Fwd
                                '-' -> return RevCompl
                                _   -> throwError $ strMsg $ "Unknown strand " ++ show ch

-- | Parse one mismatch field: a reference nt, the literal @->@, a run of
-- decimal digits giving the offset, a read nt, and a decimal quality that
-- must extend to the end of the field.
--
-- Raises an error with 'throwError' when the field is empty, the arrow is
-- missing, nothing follows the offset digits, or either number is rejected
-- by 'parseOffsetBStr' or 'parseIntBStr'.
parseMismatch :: (Error e, MonadError e m) => LBS.ByteString -> m SoapAlignMismatch
parseMismatch mstr
    = do (refCh, afterRef) <- orMalformed $ LBS.uncons mstr
         unless (LBS.take 2 afterRef == arrow) malformed
         let afterArrow = LBS.drop 2 afterRef
             offDigits = LBS.takeWhile isDigit afterArrow
             afterOff = LBS.dropWhile isDigit afterArrow
         (readCh, qualDigits) <- orMalformed $ LBS.uncons afterOff
         off <- parseOffsetBStr offDigits
         qualInt <- parseIntBStr qualDigits
         return $ SAM { readnt = readCh
                      , refnt = refCh
                      , offset = off
                      , qualnt = fromIntegral qualInt
                      }
    where arrow = LBS.pack "->"
          orMalformed found = maybe malformed return found
          malformed = throwError $ strMsg $ "Bio.Alignment.Soap.parseMismatch: malformed mismatch field "
                      ++ (show $ LBS.unpack mstr)

-- | Parse a field that consists entirely of an 'Int' as read by
-- 'LBS.readInt'.  A field that does not start with an integer, or that has
-- anything left over after it, raises an error with 'throwError'.
parseIntBStr :: (Error e, MonadError e m) => LBS.ByteString -> m Int
parseIntBStr zstr = maybe malformed return wholeField
    where
          -- 'Just' the value only when the integer spans the entire field.
          wholeField = do (z, rest) <- LBS.readInt zstr
                          if LBS.null rest then Just z else Nothing
          malformed = throwError $ strMsg $ "parseIntBStr: Malformed int field " ++ show zstr

-- | Parse a field that consists entirely of an integer as read by
-- 'LBS.readInteger', converting the result to an 'Offset' with
-- 'fromIntegral'.  A field that does not start with an integer, or that has
-- anything left over after it, raises an error with 'throwError'.
parseOffsetBStr :: (Error e, MonadError e m) => LBS.ByteString -> m Offset
parseOffsetBStr zstr = wholeField $ LBS.readInteger zstr
    where wholeField (Just (z, rest)) | LBS.null rest = return $ fromIntegral z
          wholeField _ = throwError $ strMsg $ "parseOffsetBStr: Malformed integer field " ++ show zstr

-- | Parse a field that must be exactly one character long and return that
-- character.  An empty field, or one with more than one character, raises
-- an error with 'throwError'.
--
-- The field is taken apart with 'LBS.uncons' rather than measured with
-- 'LBS.length': only the first character and the emptiness of the remainder
-- are inspected, so an over-long field is rejected without walking all of
-- it, and no partial function is involved.
parseCharBStr :: (Error e, MonadError e m) => LBS.ByteString -> m Char
parseCharBStr chstr = case LBS.uncons chstr of
                        Just (ch, rest) | LBS.null rest -> return ch
                        _ -> throwError $ strMsg $ "parseCharBStr: Malformed char field " ++ show chstr

unparse :: SoapAlign -> LBS.ByteString
unparse sa = LBS.intercalate (LBS.singleton '\t') $ [ name sa
                                                    , sequ sa
                                                    , LBSW.map (+ qualScale) $ qual sa
                                                    , LBS.pack $ show $ nhit sa
                                                    , LBS.singleton $ pairend sa
                                                    , LBS.pack $ show $ length sa
                                                    , unparseStrand $ strand sa
                                                    , refname sa
                                                    , LBS.pack $ show $ refstart sa
                                                    , LBS.pack $ show $ nmismatch sa
                                                    ] ++ map unparseMismatch (mismatches sa)
    where unparseStrand Fwd      = LBS.singleton '+'
          unparseStrand RevCompl = LBS.singleton '-'

-- | Render one mismatch in the textual form read by 'parseMismatch':
-- reference nt, @->@, the offset, read nt, then the quality, with both
-- numbers rendered by 'show'.
unparseMismatch :: SoapAlignMismatch -> LBS.ByteString
unparseMismatch sam = LBS.pack $ concat [ [refnt sam]
                                        , "->"
                                        , show (offset sam)
                                        , [readnt sam]
                                        , show (qualnt sam)
                                        ]

-- | Split a list of alignments into runs of adjacent alignments that share
-- the same read 'name'.  Alignments with equal names that are not adjacent
-- in the input end up in separate groups, as with 'groupBy'.
group :: [SoapAlign] -> [[SoapAlign]]
group = groupBy sameName
    where sameName a b = name a == name b