{- |
   Quality-aware alignments

   Generally, quality data are ignored for alignment\/pattern searching
   like Smith-Waterman, Needleman-Wunsch, or BLAST(p|n|x).  I believe
   that accounting for quality will at the very least affect things like
   BLAST statistics, and e.g. is crucial for good EST annotation using Blastx.

   This module performs sequence alignments that take quality values
   into account.

   See also <http://bioinformatics.oxfordjournals.org/cgi/content/abstract/btn052v1>.
-}
{-# LANGUAGE CPP, ParallelListComp #-}
{-# OPTIONS_GHC -fexcess-precision #-}

-- #define DEBUG

module Bio.Alignment.QAlign (
   -- * Smith-Waterman
   -- | Locally optimal alignment with affine gaps, i.e. best infix match.
   local_score, local_align

   -- * Needleman-Wunsch 
   -- | Globally optimal alignment with affine gaps, the whole sequences are matched.
   , global_score, global_align

   -- * Overlapping alignment.
   -- | The suffix of one sequence matches a prefix of another.
   , overlap_score, overlap_align

   -- * Matrix construction
   , qualMx

   -- * Interactive testing of alignments
   , test
   ) where

import Data.List (maximumBy,partition,unfoldr,zip4,tails)
import qualified Data.ByteString.Lazy as B
import qualified Data.ByteString.Lazy.Char8 as BC
import Data.Array.Unboxed

import Bio.Sequence.SeqData hiding ((!))
import Bio.Alignment.AlignData (Chr,Edit(..),SubstMx,EditList,on,isRepl,showalign)

-- | A selector combines the candidate moves for one matrix cell into the
--   cell's value.  Each candidate is a tuple of: the value of the neighbouring
--   cell the move starts from, the 'Edit' performed, and two quality values.
--   Interior cells receive three candidates (deletion, insertion,
--   replacement); cells in the first row or first column receive one.
--   For a replacement the qualities are those of the two aligned residues.
--   The selector must take into account the quality of the sequences
--   on Ins\/Del, the average of qualities surrounding the gap is (should be) used
type QSelector a = [(a,Edit,Qual,Qual)] -> a

-- | Build the alignment matrix column by column.
--
--   The result is a list of columns: one for the empty prefix of @s1@,
--   followed by one per residue of @s1@.  Each column has one cell for the
--   empty prefix of @s2@, followed by one per residue of @s2@.  @z@ is the
--   top-left (empty vs. empty) cell.  The first row and first column are
--   produced by applying the selector to a single candidate (a run of
--   deletions and insertions, respectively).  Interior cells are produced
--   by 'columns''.
--
--   All candidates put the quality on the @s1@ side third and the quality on
--   the @s2@ side fourth.
--
--   Sequences without quality data get a quality of 22 for every residue.
--   With the 'qualMx' scores this gives a match\/mismatch ratio of roughly
--   1:-3, i.e. the blastn default.
columns :: QSelector a -> a -> Sequence t -> Sequence t -> [[a]]
columns sel z s1 s2 = columns' sel (z,r0,c0) s1' s2'
    where (s1',s2') = (zup s1, zup s2)
          zup :: Sequence t -> [(Chr,Qual)]
          zup (Seq _ sd Nothing) = map (\c -> (c,22)) $ B.unpack sd
          zup (Seq _ sd (Just qd)) = zip (B.unpack sd) (B.unpack qd)

          -- the first row consists of increasing numbers of deletions of s1
          -- residues: the s1 quality goes third, the (only) flanking s2
          -- quality fourth, matching the interior deletion candidates
          r0 = map (sel . return)
               (zip4 (z:r0) (map (Del . fst) s1') (map snd s1') (repeat (snd $ head s2')))
          -- the first column consists of increasing numbers of insertions
          -- of s2 residues
          c0 = map (sel . return)
               (zip4 (z:c0) (map (Ins . fst) s2') (repeat (snd $ head s1')) (map snd s2'))

-- | Fill in the alignment matrix given its border.
--
--   Arguments: the selector, the triple (top-left cell, first row without
--   the corner, first column without the corner), and the (residue, quality)
--   lists of the two sequences.  Returns all columns, starting with the
--   first column (top-left cell followed by the first-column cells).  Each
--   further column corresponds to the next residue of @s1@.
columns' :: QSelector a -> (a,[a],[a]) -> [(Chr,Qual)] -> [(Chr,Qual)] -> [[a]]
columns' f (topleft,top,left) s1 s2 = let
        c0 = (topleft : left)
        -- given the previous column, the remaining first-row values and the
        -- remains of s1, calculate the next column (stop when s1 is exhausted)
        mkcol (ts, p0:prev, x) = if null x then Nothing else
               let (xi,qi) = head x
                   -- each new cell chooses (via ParallelListComp zipping) among:
                   --   del: delete xi, from the cell to the left (previous column)
                   --   ins: insert the s2 residue, from the cell above (this
                   --        column, which is defined lazily in terms of itself)
                   --   rep: replace xi by the s2 residue, from the diagonal cell
                   -- gap qualities use avg2 of the flanking qualities
                   c  = head ts : [f [del,ins,rep]
                           | del <- zip4 prev (repeat $ Del xi) (repeat qi) (avg2 $ map snd s2)
                           | ins <- zip4 c (map (Ins . fst) s2) (repeat $ head $ avg2 $ map snd x) (map snd s2)
                           | rep <- zip4 (p0:prev) (map (Repl xi . fst) s2) (repeat qi) (map snd s2)]
               in Just (c, (tail ts, c, tail x))
    in c0 : unfoldr mkcol (top,c0,s1)

-- | Average each quality with its successor.  The last element, which has
--   no successor, is kept as is.
--
--   'columns'' uses this to estimate the quality at a gap from the two
--   residues flanking it.  The result has the same length as the input, and
--   an empty input gives an empty result.  (The previous version returned
--   one extra trailing element that was an 'error' thunk.)
avg2 :: [Qual] -> [Qual]
avg2 (q1 : rest@(q2 : _)) = mean q1 q2 : avg2 rest
  where -- Qual is an 8-bit value here (it comes from B.unpack), so widen to
        -- Int before adding; two large qualities would otherwise wrap around
        mean a b = fromIntegral ((fromIntegral a + fromIntegral b :: Int) `div` 2)
avg2 [q] = [q]
avg2 []  = []



-- | A large negative score standing in for minus infinity.  It is used as
--   the lower bound for cells that may not start a fresh alignment.
minf :: Double
minf = negate 1.0e8

-- | A quality-dependent substitution matrix: given the qualities of the two
--   residues being aligned, it yields the 'SubstMx' used to score them.
type QualMx t a = Qual -> Qual -> SubstMx t a

-- | Quality-aware substitution score for aligning residue @x@ (quality @q1@)
--   against residue @y@ (quality @q2@), as a log2-odds value from 'adjust'.
--
--   * If either residue is @N@ or @X@ (either case), the score is 0.
--
--   * Two residues count as identical if they are equal or differ by 32
--     (ASCII upper\/lower case).  NOTE(review): this test also pairs some
--     non-letter bytes (e.g. @-@ and @M@), and the 8-bit arithmetic wraps.
--
--   * Qualities must lie in 0..99, the table bounds.  Out-of-range values
--     make the array lookup fail, or raise an explicit error when compiled
--     with DEBUG.
--
--   The score tables are top-level constants, so they are built once rather
--   than on every call.  Without optimisation (e.g. in GHCi) GHC does not
--   float them out of the function by itself.
qualMx :: Qual -> Qual -> (Chr,Chr) -> Double
qualMx q1 q2 (x,y) = if isN x || isN y then 0.0 else
#ifdef DEBUG
                        if q1 < 0 || q1 > 99 || q2 < 0 || q2 > 99
                        then error ("Qualities out of range: "++show (q1,q2))
                        else
#endif
                        if x==y || x+32==y || x-32==y
                        then qualMatchTbl!(q1,q2) else qualMismatchTbl!(q1,q2)
  where isN c = c `elem` [78,88,110,120]   -- N, X, n, x

-- Precomputed scores for matching and mismatching residues, indexed by the
-- pair of qualities (0..99 each).
qualMatchTbl, qualMismatchTbl :: UArray (Qual,Qual) Double
qualMatchTbl    = qualTbl True
qualMismatchTbl = qualTbl False

-- | Tabulate @adjust s@ over every quality pair in 0..99.
qualTbl :: Bool -> UArray (Qual,Qual) Double
qualTbl s = array ((0,0),(99,99)) [((a,b), adjust s a b) | a <- [0..99], b <- [0..99]]

-- | Log2-odds score for a residue pair with Phred qualities @q1@ and @q2@.
--
--   Each quality becomes an error probability @10 ** (-q\/10)@.  The two
--   probabilities are combined as @e = e1 + e2 - 4\/3 * e1 * e2@.  The score
--   is @log2 (4 * (1 - e))@ for a match (@s == True@) and
--   @log2 (4\/3 * e)@ for a mismatch.
adjust :: Bool -> Qual -> Qual -> Double
adjust s q1 q2 = logBase 2 odds
  where
    errProb q = 10 ** negate (fromIntegral q / 10)
    pa = errProb q1
    pb = errProb q2
    e  = pa + pb - 4/3 * pa * pb
    odds | s         = 4 * (1 - e)
         | otherwise = 4/3 * e

-- ------------------------------------------------------------
-- Edit distances

-- | Global alignment score (Needleman-Wunsch): the better of the two
--   components (ending in a replacement, ending in a gap) of the
--   bottom-right matrix cell.
global_score :: QualMx t Double -> (Double,Double) -> Sequence t -> Sequence t -> Double
global_score mx g s1 s2 = max endRepl endGap
  where
    matrix            = columns (score_select minf mx g) (0, fst g) s1 s2
    (endRepl, endGap) = last (last matrix)

-- | Local alignment score (Smith-Waterman): the best value found anywhere
--   in the matrix.  Cells are floored at 0, so an alignment may start at
--   any position.
local_score :: QualMx t Double -> (Double,Double) -> Sequence t -> Sequence t -> Double
local_score mx g s1 s2 = maximum [ max subst gap | column <- matrix, (subst, gap) <- column ]
  where matrix = columns (score_select 0 mx g) (0, fst g) s1 s2

-- | Best overlap score, where gaps at the edges are free.
--
--   Every cell in the first row and first column scores 0, so a leading
--   overhang costs nothing.  The result is the best value anywhere in the
--   bottom row or the last column.
--
--   The local selector cannot be reused here: flooring every cell at 0 would
--   let a new alignment start anywhere.  Instead, interior cells use 'minf'
--   as their floor, and only the border cells are given a free score of 0.
overlap_score :: QualMx t Double -> (Double,Double) -> Sequence t -> Sequence t -> Double
overlap_score mx g s1 s2 = maximum [ max subst gap | (subst, gap) <- bottomRow ++ lastColumn ]
  where
    cols       = columns edgeAware (0, fst g) s1 s2
    bottomRow  = map last cols   -- final cell of every column
    lastColumn = last cols       -- every cell of the final column
    -- border cells receive a single candidate, interior cells three
    edgeAware cds = case cds of
      [_]     -> (0, minf)
      [_,_,_] -> score_select minf mx g cds

-- | Cell selector for the score-only functions.
--
--   A cell value is the pair (best score ending in a replacement, best score
--   ending in a gap).
--
--   * Replacement candidates add the quality-dependent substitution score
--     to the better of the neighbour's two scores.
--
--   * Gap candidates either open a gap (neighbour's replacement score plus
--     @go@) or extend one (neighbour's gap score plus @ge@).
--
--   Both results are bounded below by the first argument: 0 for local
--   scoring, 'minf' otherwise.  Both components are forced before the pair
--   is returned; the original author observed this makes local scoring
--   slower but overlap scoring faster.
score_select :: Double -> QualMx t Double -> (Double,Double) -> QSelector (Double,Double)
score_select lowest mx (go,ge) cds = bestRepl `seq` bestGap `seq` (bestRepl, bestGap)
  where
    (replCands, gapCands) = partition (isRepl . snd') cds
    bestRepl = maximum (lowest : [ max sub gap + mx q1 q2 (x,y)
                                 | ((sub,gap), Repl x y, q1, q2) <- replCands ])
    bestGap  = maximum (lowest : [ max (sub + go) (gap + ge)
                                 | ((sub,gap), _, _, _) <- gapCands ])

-- ------------------------------------------------------------
-- Alignments (rip from AAlign)

-- | The better of two (score, edits) candidates.  On a tie the first
--   argument wins.
max' :: (Double,EditList) -> (Double,EditList) -> (Double,EditList)
max' a@(x,_) b@(y,_)
  | x >= y    = a
  | otherwise = b

-- | Addition for compound values: add a step's score to the running total
--   and prepend its edit.  Edit lists are therefore built in reverse order.
fp :: (Double,EditList) -> (Double,Edit) -> (Double,EditList)
fp (total, edits) (delta, op) =
  let total' = total + delta
  in (total', op : edits)

-- | Global alignment (Needleman-Wunsch): the better candidate of the
--   bottom-right cell, with its edit list put into alignment order.
global_align :: QualMx t Double -> (Double,Double) -> Sequence t -> Sequence t -> (Double,EditList)
global_align mx g s1 s2 = revsnd (max' endRepl endGap)
  where
    matrix            = columns (align_select minf mx g) ((0,[]), (fst g,[])) s1 s2
    (endRepl, endGap) = last (last matrix)

-- | Local alignment (Smith-Waterman): the best-scoring cell anywhere in the
--   matrix, with its edit list in alignment order.  Cells are floored at 0,
--   so an alignment may start at any position.
--
--   NOTE(review): the original asked whether taking only the replacement
--   component, instead of the better of both, would suffice, since a local
--   alignment presumably always ends on a substitution.
local_align :: QualMx t Double -> (Double,Double) -> Sequence t -> Sequence t -> (Double, EditList)
local_align mx g s1 s2 = revsnd best
  where
    matrix = columns (align_select 0 mx g) ((0,[]), (fst g,[])) s1 s2
    best   = maximumBy (compare `on` fst) [ max' sub gap | (sub, gap) <- concat matrix ]

-- | Best overlap alignment, where gaps at the edges are free.
--
--   Border cells score 0, and interior cells use 'minf' as their floor.  The
--   result is the best candidate anywhere in the bottom row or last column,
--   with its edit list in alignment order.
overlap_align :: QualMx t Double -> (Double,Double) -> Sequence t -> Sequence t -> (Double,EditList)
overlap_align mx g s1 s2 = revsnd best
  where
    cols   = columns edgeAware ((0,[]), (minf,[])) s1 s2
    border = map last cols ++ last cols
    best   = maximumBy (compare `on` fst) [ max' sub gap | (sub, gap) <- border ]
    -- border cells receive a single candidate, interior cells three
    edgeAware cds = case cds of
      [_]     -> ((0,[]), (minf,[]))
      [_,_,_] -> align_select minf mx g cds

-- | Variant of 'overlap_align' that keeps the leading indels in the edit
--   list, so the result covers the whole sequences.  Border cells still
--   score 0, but each one extends the edit list of its predecessor.
overlap_align' :: QualMx t Double -> (Double,Double) -> Sequence t -> Sequence t -> (Double,EditList)
overlap_align' mx g s1 s2 = revsnd best
  where
    cols   = columns keepEdges ((0,[]), (fst g,[])) s1 s2
    border = map last cols ++ last cols
    best   = maximumBy (compare `on` fst) [ max' sub gap | (sub, gap) <- border ]
    -- interior cells select normally; a border cell prepends its edit to
    -- the previous border cell's (gap) edit list at zero cost
    keepEdges cds = case cds of
      [_,_,_]                          -> align_select minf mx g cds
      [(((_,_), (_,prevEdits)), e, _, _)] -> ((0, e:prevEdits), (0, e:prevEdits))

-- | Reverse the list in the second component of a pair.  The alignment
--   selectors build edit lists by prepending (see 'fp'), so this puts a
--   result's edits into alignment order.
--   (Open question from the original: would it be better to reverse the
--   inputs for global alignment instead?)
revsnd :: (a,[b]) -> (a,[b])
revsnd (s,a) = (s,reverse a)

-- | Cell selector for the alignment functions.  It works like the scoring
--   selector, but each score carries its (reversed) edit list.
--
--   * Replacement candidates extend the better of the neighbour's two
--     candidates with the replacement edit and its quality-dependent score.
--
--   * Gap candidates either open a gap (neighbour's replacement candidate
--     plus @go@) or extend one (neighbour's gap candidate plus @ge@).
--
--   Both results are bounded below by (first argument, no edits).
align_select :: Double -> QualMx t Double -> (Double,Double) -> QSelector ((Double,EditList),(Double,EditList))
align_select lowest mx (go,ge) cds = (bestRepl, bestGap)
  where
    (replCands, gapCands) = partition (isRepl . snd') cds
    best xs  = maximumBy (compare `on` fst) xs
    bestRepl = best ((lowest,[]) : [ max' sub gap `fp` (mx q1 q2 (x,y), e)
                                   | ((sub,gap), e@(Repl x y), q1, q2) <- replCands ])
    bestGap  = best ((lowest,[]) : [ max' (sub `fp` (go,e)) (gap `fp` (ge,e))
                                   | ((sub,gap), e, _, _) <- gapCands ])

-- | Second component of a 4-tuple.  Used to pick the edit out of a
--   selector candidate.
snd' :: (a,b,c,d) -> b
snd' (_,x,_,_) = x

-- | Interactive check: reads two lines from standard input. It then prints
--   the global, overlap (forward and reverse-complemented) and local
--   alignments of the two lines, using 'qualMx' and gap scores (-5,-2).
test :: IO ()
test = do
  putStrLn "Enter two strings:"
  line1 <- getLine
  line2 <- getLine
  let mkSeq label str = Seq (BC.pack label) (BC.pack str) Nothing
      sa   = mkSeq "foo" line1
      sb   = mkSeq "bar" line2
      gaps = (-5,-2)
      report title (score, edits) = do
        putStrLn (title ++ show score)
        putStrLn (showalign edits)
  report "GLOBAL:   " (global_align qualMx gaps sa sb)
  report "OVERLAP:  " (overlap_align qualMx gaps sa sb)
  report "OVERLAP (rc):  " (overlap_align qualMx gaps (revcompl sa) (revcompl sb))
  report "LOCAL:    " (local_align qualMx gaps sa sb)