{- |
   Quality-aware alignments

   Quality data are generally ignored by alignment\/pattern-search methods
   such as Smith-Waterman, Needleman-Wunsch, or BLAST(p|n|x).  I believe
   that accounting for quality will at the very least affect things like
   BLAST statistics, and that it is crucial for, e.g., good EST annotation
   using Blastx.

   This module performs sequence alignments that take quality values into
   account.
-}
{-# LANGUAGE CPP, ParallelListComp #-}
{-# OPTIONS_GHC -fexcess-precision #-}

-- #define DEBUG

module Bio.Alignment.QAlign (
   -- * Smith-Waterman, or locally optimal alignment with affine gaps
   local_score, local_align

   -- * Needleman-Wunsch, or globally optimal alignment with affine gaps
   , global_score, global_align

   -- * Matrix construction
   , qualMx
   ) where

import Data.List (maximumBy,partition,unfoldr,zip4,tails)
import qualified Data.ByteString.Lazy as B
import Data.Array.Unboxed

import Bio.Sequence.SeqData hiding ((!))
import Bio.Alignment.AlignData (Chr,Edit(..),SubstMx,EditList,on,isRepl,showalign)

-- | A selector combines the candidate predecessors of a dynamic-programming
--   cell into that cell's value.  Each candidate is a tuple of
--
--   * the predecessor cell's value,
--   * the edit operation leading from it to this cell,
--   * two qualities: by the convention used in 'columns'', the first comes
--     from the first sequence and the second from the second sequence.
--
--   For an insertion or deletion, the quality on the gapped side is (meant to
--   be) the average of the qualities surrounding the gap, so selectors can
--   take quality into account on Ins\/Del as well as on replacements.
type QSelector a = [(a,Edit,Qual,Qual)] -> a

-- | Build the dynamic-programming matrix for two sequences, column by column:
--   an initial column, then one column per residue of the first sequence.
--
--   Sequences that carry no quality data get a default quality of 22 at
--   every position.  With 'qualMx', 22 gives match and mismatch scores in
--   roughly the 1:-3 ratio of blastn's default scoring.
--
--   @z@ is the value of the top-left cell.  The first row (deletions of
--   first-sequence residues) and the first column (insertions of
--   second-sequence residues) are each built by applying the selector to a
--   single candidate chained from the previous cell.  Qualities are passed in
--   (first sequence, second sequence) order, matching 'columns''.
columns :: QSelector a -> a -> Sequence -> Sequence -> [[a]]
columns sel z s1 s2 = columns' sel (z,r0,c0) s1' s2'
    where
        (s1',s2') = (zup s1, zup s2)

        -- quality assumed for sequences without quality data
        defQual :: Qual
        defQual = 22

        zup :: Sequence -> [(Chr,Qual)]
        zup (Seq _ sd Nothing)   = map (\c -> (c,defQual)) $ B.unpack sd
        zup (Seq _ sd (Just qd)) = zip (B.unpack sd) (B.unpack qd)

        -- quality of the first residue; it flanks a gap placed before it
        -- (total: falls back to the default for an empty sequence)
        firstQual s = case s of
                        ((_,q):_) -> q
                        []        -> defQual

        -- the first row consists of increasing numbers of deletions:
        -- the deleted residue's quality first, the gapped side second
        r0 = map (sel . return)
             (zip4 (z:r0) (map (Del . fst) s1') (map snd s1') (repeat (firstQual s2')))
        -- the first column consists of increasing numbers of insertions
        c0 = map (sel . return)
             (zip4 (z:c0) (map (Ins . fst) s2') (repeat (firstQual s1')) (map snd s2'))

-- | Lazily build the dynamic-programming matrix column by column.
--
--   The first column is @topleft : left@.  Each further column corresponds
--   to one residue of @s1@.  It starts with the matching element of @top@
--   (the first-row value), and every other cell is the selector @f@ applied
--   to three candidates:
--
--   * deletion of the current @s1@ residue, from the cell in the previous
--     column on the same row;
--   * insertion of an @s2@ residue, from the cell just above in this column;
--   * replacement, from the diagonal cell in the previous column.
--
--   Qualities are passed as (first sequence, second sequence).  On the
--   gapped side an average of neighbouring qualities (see 'avg2') is used.
columns' :: QSelector a -> (a,[a],[a]) -> [(Chr,Qual)] -> [(Chr,Qual)] -> [[a]]
columns' f (topleft,top,left) s1 s2 = c0 : unfoldr mkcol (top, c0, s1)
    where
        c0 = topleft : left

        -- Everything derived only from s2 is the same for every column,
        -- so build it once instead of once per column.
        s2quals = map snd s2
        s2gapq  = avg2 s2quals
        insOps  = map (Ins . fst) s2

        -- Given the remaining top row, the previous column and the remaining
        -- residues of s1, build the next column.  Note that 'c' is defined in
        -- terms of itself: each insertion candidate uses the cell above it.
        mkcol (_, _, []) = Nothing
        mkcol (t:ts, prev@(_:prevTail), x@((xi,qi):rest)) =
            let insq = case avg2 (map snd x) of
                         (q:_) -> q
                         []    -> qi
                c = t : [ f [del, ins, rep]
                        | del <- zip4 prevTail (repeat (Del xi)) (repeat qi) s2gapq
                        | ins <- zip4 c insOps (repeat insq) s2quals
                        | rep <- zip4 prev (map (Repl xi . fst) s2) (repeat qi) s2quals ]
            in Just (c, (ts, c, rest))
        -- unreachable when 'top' has one entry per residue of s1
        mkcol _ = error "columns': top row shorter than first sequence"

-- | Average each quality with its successor; the last quality is kept as is.
--   The result has exactly as many elements as the input: for example,
--   @[10,20,40]@ gives @[15,30,40]@.  It is used for the quality of a gap,
--   which sits between two residues.
--
--   The sum is taken in 'Int', so large quality values (whose sum would not
--   fit in the quality type) cannot wrap around.
avg2 :: [Qual] -> [Qual]
avg2 (q1:rest@(q2:_)) = mean q1 q2 : avg2 rest
    where mean a b = fromIntegral ((fromIntegral a + fromIntegral b :: Int) `div` 2)
avg2 short = short   -- empty or single-quality list: nothing to average



-- | Stand-in for minus infinity: a very low score used as the floor
--   by the global scoring and alignment functions.  It is presumably meant
--   to be far below any score a real alignment can reach.
minf :: Double
minf = negate 1.0e8

-- | A substitution matrix that also depends on the qualities of the two
--   residues being aligned.  It is applied as @mx q1 q2 (x,y)@: the two
--   qualities first, then the residue pair.
type QualMx a = Qual -> Qual -> SubstMx a

-- | Quality-aware substitution score for a pair of residues.
--
--   Returns 0 if either residue is N or X (in either case).  Otherwise the
--   residues are compared ASCII case-insensitively, and the match or mismatch
--   score for qualities @q1@, @q2@ is looked up in a table precomputed with
--   'adjust'.  Qualities must lie in @[0,99]@.  Anything else is an array
--   index error, checked explicitly when compiled with DEBUG.
qualMx :: Qual -> Qual -> (Chr,Chr) -> Double
qualMx q1 q2 (x,y) = if isN x || isN y then 0.0 else
#ifdef DEBUG
                        if q1 < 0 || q1 > 99 || q2 < 0 || q2 > 99
                        then error ("Qualities out of range: "++show (q1,q2))
                        else
#endif
                        if toUpperC x == toUpperC y
                        then qualMatchTbl!(q1,q2) else qualMismatchTbl!(q1,q2)
  where isN c = c `elem` [78,88,110,120]   -- 'N', 'X', 'n', 'x'
        -- Fold ASCII lower-case letters to upper case and leave every other
        -- byte alone.  (Comparing with x+32 / x-32 also paired up non-letters
        -- such as '!' with 'A', or '-' with 'M'.)
        toUpperC c = if c >= 97 && c <= 122 then c - 32 else c

-- | Match and mismatch scores for every quality pair in @[0,99]@.  These are
--   top-level constants, so they are built once rather than on each call of
--   'qualMx'.
qualMatchTbl, qualMismatchTbl :: UArray (Qual,Qual) Double
qualMatchTbl    = qualTbl True
qualMismatchTbl = qualTbl False

-- | Tabulate 'adjust' (match flag fixed) over all quality pairs in @[0,99]@.
qualTbl :: Bool -> UArray (Qual,Qual) Double
qualTbl s = array ((0,0),(99,99)) [((a,b), adjust s a b) | a <- [0..99], b <- [0..99]]

-- | Log-odds score, in bits, for a match (@True@) or a mismatch (@False@)
--   between two residues with the given Phred qualities.
--
--   Each quality @q@ becomes an error probability @10^(-q\/10)@.  @e@ is the
--   probability that the two observed residues disagree, assuming
--   independent errors spread evenly over the three other bases.  The
--   factors 4 and 4\/3 presumably normalise against a uniform four-letter
--   background (match probability 1\/4, mismatch 3\/4).
adjust :: Bool -> Qual -> Qual -> Double
adjust isMatch qa qb = logBase 2 odds
    where
        errProb q = 10 ** negate (fromIntegral q / 10)
        ea = errProb qa
        eb = errProb qb
        e  = ea + eb - 4/3*ea*eb
        odds | isMatch   = 4 * (1 - e)
             | otherwise = 4/3 * e

-- ------------------------------------------------------------
-- Edit distances

-- | Global alignment score (Needleman-Wunsch) of two sequences.
--
--   @mx@ is the quality-aware substitution matrix and @g@ is the pair
--   (gap open, gap extend).  The score is the better of the substitution and
--   gap values in the bottom-right cell of the matrix.
global_score :: QualMx Double -> (Double,Double) -> Sequence -> Sequence -> Double
global_score mx g s1 s2 = max sub gap
    where
        dpCols     = columns (score_select minf mx g) (0, fst g) s1 s2
        (sub, gap) = last (last dpCols)

-- | Local alignment score (Smith-Waterman) of two sequences.
--
--   Scores are floored at 0, so an alignment may start anywhere.  The result
--   is the best value found in any cell of the matrix.
local_score :: QualMx Double -> (Double,Double) -> Sequence -> Sequence -> Double
local_score mx g s1 s2 = maximum [ max sub gap | (sub, gap) <- concat dpCols ]
    where dpCols = columns (score_select 0 mx g) (0, fst g) s1 s2

-- | Cell selector shared by global and local scoring.
--
--   A cell holds a pair: (best score ending in a substitution, best score
--   ending in a gap).  Opening a gap from a substitution costs @go@, and
--   extending an existing gap costs @ge@.  @lo@ is the floor for both
--   components: 'minf' for global scoring, and 0 for local scoring.
score_select :: Double -> QualMx Double -> (Double,Double) -> QSelector (Double,Double)
score_select lo mx (go,ge) cds = (bestSub, bestGap)
    where
        (reps, ids) = partition (isRepl . snd') cds
        bestSub = maximum (lo : [ max sub gap + mx q1 q2 (x,y)
                                | ((sub,gap), Repl x y, q1, q2) <- reps ])
        bestGap = maximum (lo : [ max (sub + go) (gap + ge)
                                | ((sub,gap), _, _, _) <- ids ])

-- ------------------------------------------------------------
-- Alignments (rip from AAlign)

-- | The higher-scoring of two scored edit lists.  Ties go to the first
--   argument.
max' :: (Double,EditList) -> (Double,EditList) -> (Double,EditList)
max' a@(x,_) b@(y,_)
    | x >= y    = a
    | otherwise = b

-- | Extend a scored edit list by one edit, adding the edit's score.  The
--   edit is consed onto the front, so the list is accumulated in reverse.
--   The alignment functions undo this with 'revsnd'.
fp :: (Double,EditList) -> (Double,Edit) -> (Double,EditList)
fp (total, edits) (delta, op) = (total + delta, op : edits)

-- | Global alignment (Needleman-Wunsch): the best-scoring alignment of the
--   two complete sequences, as its score together with its edit list.
global_align :: QualMx Double -> (Double,Double) -> Sequence -> Sequence -> (Double,EditList)
global_align mx g s1 s2 = revsnd (max' sub gap)
    where
        dpCols     = columns (align_select minf mx g) ((0,[]), (fst g,[])) s1 s2
        (sub, gap) = last (last dpCols)

-- | Local alignment (Smith-Waterman): the best-scoring alignment between
--   substrings of the two sequences, as its score together with its edit
--   list.  Scores are floored at 0, so an alignment may start anywhere.
local_align :: QualMx Double -> (Double,Double) -> Sequence -> Sequence -> (Double, EditList)
local_align mx g s1 s2 = revsnd best
    where
        dpCols = columns (align_select 0 mx g) ((0,[]), (fst g,[])) s1 s2
        best   = maximumBy (compare `on` fst) [ max' sub gap | (sub, gap) <- concat dpCols ]

{-
-- | Calculate the optimal match between a suffix of one sequence and a prefix
--   of the other (useful for e.g. sequence assembly)
overlap_align mx g s1 s2 =
     columns (align_select minf mx g) ((0,[]),(fst g,[])) s1 s2
-}

-- | Reverse the list in the second component of a pair.  Edit lists are
--   accumulated back to front during alignment, so they are reversed once
--   at the end.
--   (maybe better to reverse the inputs for global?)
revsnd :: (a,[b]) -> (a,[b])
revsnd (s,a) = (s,reverse a)

-- | Cell selector shared by global and local alignment.
--
--   This works like 'score_select', but each score carries the (reversed)
--   edit list that produced it.  Among candidates with equal scores, the
--   later one wins, as 'maximumBy' does.  @lo@ is the floor score and is
--   paired with an empty edit list.
align_select :: Double -> QualMx Double -> (Double,Double) -> QSelector ((Double,EditList),(Double,EditList))
align_select lo mx (go,ge) cds = (bestSub, bestGap)
    where
        (reps, ids) = partition (isRepl . snd') cds
        byScore xs  = maximumBy (compare `on` fst) xs
        bestSub = byScore ((lo,[]) : [ fp (max' sub gap) (mx q1 q2 (x,y), op)
                                     | ((sub,gap), op@(Repl x y), q1, q2) <- reps ])
        bestGap = byScore ((lo,[]) : [ max' (fp sub (go,op)) (fp gap (ge,op))
                                     | ((sub,gap), op, _, _) <- ids ])

-- | Second component of a 4-tuple.  In a 'QSelector' candidate this is the
--   edit operation.
snd' :: (a,b,c,d) -> b
snd' (_,x,_,_) = x