{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE BangPatterns #-}
module Bio.Motif.Alignment
    ( alignment
    , alignmentBy
    , linPenal
    , quadPenal
    , cubPenal
    , expPenal
    , l1
    , l2
    , l3
    , lInf
    , AlignFn
    , CombineFn
    ) where

import qualified Data.Vector.Generic as G
import qualified Data.Vector.Unboxed as U
import qualified Data.Matrix.Unboxed as M
import Statistics.Sample (mean)

import Bio.Motif
import Bio.Utils.Functions

-- | A gap-penalty function: takes the number of gaps and the number of matched
-- positions (in that order) and returns the penalty value.
type PenalFn = Int -> Int -> Double

-- | Distance between two vectors of 'Double's. The type is rank-2 polymorphic
-- in the vector type @v@, so a function that receives a 'DistanceFn' may apply
-- it to any vector type satisfying the two 'G.Vector' constraints.
type DistanceFn = forall v. (G.Vector v Double, G.Vector v (Double, Double))
               => v Double -> v Double -> Double

-- | An alignment procedure for a pair of PWMs.
type AlignFn = PWM
            -> PWM
            -> (Double, (Bool, Int))  -- ^ (distance, (on same direction,
                                      -- position w.r.t. the first pwm))

-- | Combine the distances from the different positions of an alignment into a
-- single value.
type CombineFn = U.Vector Double -> Double

-- | Default PWM alignment: 'alignmentBy' with 'jsd' as the per-position
-- distance (imported, not defined here; presumably the Jensen-Shannon
-- divergence -- confirm in "Bio.Utils.Functions"), an exponential gap penalty
-- with coefficient @0.05@ ('expPenal') and 'l1' to combine the per-position
-- distances.
alignment :: AlignFn
alignment = alignmentBy jsd (expPenal 0.05) l1

-- | Linear penalty: @nGap * x / nMatch@, where @x@ is the penalty coefficient,
-- @nGap@ the number of gaps and @nMatch@ the number of matched positions.
--
-- The division is done in 'Double', so @nMatch == 0@ does not throw; it
-- yields @Infinity@ (or @NaN@ when the numerator is zero as well).
linPenal :: Double -> PenalFn
linPenal x nGap nMatch = fromIntegral nGap * x / fromIntegral nMatch
{-# INLINE linPenal #-}

-- | Quadratic penalty: @nGap ^ 2 * x / nMatch@, where @x@ is the penalty
-- coefficient, @nGap@ the number of gaps and @nMatch@ the number of matched
-- positions.
--
-- The square is taken in 'Int' before converting to 'Double'. The division is
-- done in 'Double', so @nMatch == 0@ yields @Infinity@ or @NaN@ rather than
-- an exception.
quadPenal :: Double -> PenalFn
quadPenal x nGap nMatch =
    fromIntegral (nGap ^ (2 :: Int)) * x / fromIntegral nMatch
{-# INLINE quadPenal #-}

-- | Cubic penalty: @nGap ^ 3 * x / nMatch@, where @x@ is the penalty
-- coefficient, @nGap@ the number of gaps and @nMatch@ the number of matched
-- positions.
--
-- The cube is taken in 'Int' before converting to 'Double'. The division is
-- done in 'Double', so @nMatch == 0@ yields @Infinity@ or @NaN@ rather than
-- an exception.
cubPenal :: Double -> PenalFn
cubPenal x nGap nMatch =
    fromIntegral (nGap ^ (3 :: Int)) * x / fromIntegral nMatch
{-# INLINE cubPenal #-}

-- | Exponential penalty: @(2 ^ nGap - 1) * x / nMatch@, where @x@ is the
-- penalty coefficient, @nGap@ the number of gaps and @nMatch@ the number of
-- matched positions. The penalty is zero when there are no gaps.
--
-- The power is evaluated in 'Double' rather than 'Int': @2 ^ nGap :: Int@
-- silently wraps around once @nGap@ reaches 64, which would turn a huge
-- penalty into a non-positive one. For smaller @nGap@ the two computations
-- round to the same 'Double'.
--
-- The division is done in 'Double', so @nMatch == 0@ yields @Infinity@ or
-- @NaN@ rather than an exception. A negative @nGap@ raises the usual
-- negative-exponent error of '^'.
expPenal :: Double -> PenalFn
expPenal x nGap nMatch =
    ((2 :: Double) ^ nGap - 1) * x / fromIntegral nMatch
{-# INLINE expPenal #-}

-- | Combine per-position distances by taking their 'mean'.
l1 :: CombineFn
l1 = mean
{-# INLINE l1 #-}

-- | Combine per-position distances as the square root of the 'mean' of their
-- squares.
l2 :: CombineFn
l2 = sqrt . mean . U.map (** 2)
{-# INLINE l2 #-}

-- | Combine per-position distances as the cube root of the 'mean' of their
-- cubes.
--
-- Both powers use '**', so the result is only meaningful for non-negative
-- distances (a negative base yields @NaN@).
l3 :: CombineFn
l3 = (** (1 / 3)) . mean . U.map (** 3)
{-# INLINE l3 #-}

-- | Combine per-position distances by taking the largest one.
--
-- This is 'U.maximum', which is partial: it fails on an empty vector.
lInf :: CombineFn
lInf = U.maximum
{-# INLINE lInf #-}

-- | Ungapped alignment of two PWMs under a caller-supplied scoring scheme.
--
-- Internal gaps are not allowed: one matrix is slid along the other and only
-- the overhanging, unmatched rows count as gaps. The second PWM is tried both
-- as given and after 'rcPWM'. For every offset the score is
--
-- > combFn (row-wise distances over the overlap) + pFn nGap nMatch
--
-- A larger score means a larger distance, so the smallest score wins.
--
-- The result is @(score, (sameDirection, offset))@. @sameDirection@ is 'True'
-- when the winning alignment uses the second PWM as given (ties go to this
-- orientation) and 'False' when it uses @rcPWM m2@. @offset@ is non-negative
-- when the second PWM starts @offset@ rows into the first one, and negated
-- when the first PWM starts that many rows into the second one.
--
-- NOTE(review): both PWMs are expected to have at least one row; with an
-- empty matrix the pruning table below is indexed out of bounds.
alignmentBy :: DistanceFn  -- ^ compute the distance between two aligned pwms
            -> PenalFn     -- ^ gap penalty
            -> CombineFn   -- ^ combine the per-position distances
            -> AlignFn
alignmentBy fn pFn combFn m1 m2
    | fst forwardAlign <= fst reverseAlign =
        (fst forwardAlign, (True, snd forwardAlign))
    | otherwise = (fst reverseAlign, (False, snd reverseAlign))
  where
    -- Best (score, signed offset) against the second PWM as given ...
    forwardAlign = bestOffset s2
    -- ... and against its 'rcPWM' counterpart.
    reverseAlign = bestOffset s2'

    -- Best (score, signed offset) of @other@ against the rows of the first
    -- PWM, sliding in both directions.
    bestOffset other
        | d1 < d2 = (d1, i1)
        | otherwise = (d2, -i2)
      where
        -- @other@ starts i1 rows into the first PWM
        (d1, i1) = loop opti2 (1 / 0, -1) other s1 0
        -- the first PWM starts i2 rows into @other@
        (d2, i2) = loop opti1 (1 / 0, -1) s1 other 0

    -- Walk over the offsets 0, 1, ... by dropping leading rows of the second
    -- row list, carrying the best (score, offset) seen so far. The initial
    -- accumulator is (Infinity, -1).
    loop opti (min', i') a b@(_ : xs) !i
        -- No offset from here on can have a gap penalty below the current
        -- best, so stop early. (As a bound on the whole score this presumes
        -- that combFn never returns a negative value -- TODO confirm.)
        | opti U.! i >= min' = (min', i')
        | d < min' = loop opti (d, i) a xs (i + 1)
        | otherwise = loop opti (min', i') a xs (i + 1)
      where
        d = combFn sc + pFn nGap nMatch
        -- zipWith truncates to the overlap of the two row lists
        sc = U.fromList $ zipWith fn a b
        nMatch = U.length sc
        nGap = n1 + n2 - 2 * nMatch
    loop _ acc _ _ _ = acc

    opti1 = optimalSc n1 n2
    opti2 = optimalSc n2 n1

    -- Entry i is the smallest gap penalty over all offsets >= i when a matrix
    -- with x rows is slid along one with y rows (suffix minimum via scanr1).
    optimalSc x y = U.fromList $ scanr1 min $ go 0
      where
        go i | nM == 0 = []
             | otherwise = pFn nG nM : go (i + 1)
          where
            nM = min x (y - i)
            nG = i + abs (x - (y - i))

    s1 = M.toRows (_mat m1)
    s2 = M.toRows (_mat m2)
    s2' = M.toRows (_mat (rcPWM m2))
    n1 = length s1
    n2 = length s2
{-# INLINE alignmentBy #-}