module Matrix
    ( removeChar
    , getClusterFrequencyMap
    , getBlockMap
    , joinBlockMaps
    , getBlosum
    ) where
import Data.Maybe
import Data.Tuple
import qualified Data.Map.Strict as Map
import qualified Data.Sequence as Seq
import qualified Data.Foldable as F
import Control.Applicative
import Control.Lens
import Data.Fasta.Text (FastaSequence)
import Types
import Utility
-- | Pair every residue with the same frequency value.
--
-- Built with 'map' rather than by zipping against the enumeration
-- @[x, x ..]@, so the result does not depend on how 'Frequency' implements
-- 'Enum' (an instance whose 'enumFromThen' goes through 'fromEnum' would
-- otherwise truncate the value).
zipSize :: Frequency -> [AA] -> [(AA, Frequency)]
zipSize x = map (\aa -> (aa, x))
-- | Tag each element of a list with its position, counting up from 1.
zipPosition :: [a] -> [(Position, a)]
zipPosition = zipWith (,) [1 ..]
-- | Build the per-position residue table for one cluster of sequences.
--
-- Every residue of every sequence in @xs@ is keyed by its 1-based column
-- position. Residues sharing a position are gathered into one sequence
-- ('Map.fromListWith' puts later input sequences first). Each residue is
-- then tagged with the cluster size, i.e. the number of sequences in @xs@,
-- as its 'Frequency'.
getClusterFrequencyMap :: Seq.Seq FastaSequence -> ClusterFrequencyMap
getClusterFrequencyMap xs =
    ClusterFrequencyMap (Map.map tagWithSize byPosition)
  where
    -- All (position, residue) entries of the cluster, grouped by position.
    byPosition = Map.fromListWith (Seq.><) (foldMap columnEntries xs)

    -- One singleton entry per residue of a single sequence.
    columnEntries fasta = zipPosition [ Seq.singleton aa | aa <- getSeq fasta ]

    tagWithSize residues = Seq.fromList (zipSize clusterSize (F.toList residues))

    clusterSize = Frequency (fromIntegral (Seq.length xs))
-- | Strip unwanted residues from a residue-pair map.
--
-- With 'Nothing' the map is returned unchanged. With @Just badChars@,
-- every outer key and every inner key listed in @badChars@ is dropped.
-- When @allFlag@ is 'True' and any bad residue occurs anywhere in the map
-- (as an outer or an inner key), the result is an empty map instead.
removeChar :: Bool -> Maybe [AA] -> AAMap -> AAMap
removeChar allFlag = maybe id strip
  where
    strip badChars wrapped = AAMap . dropBad . Map.map dropBad $ kept
      where
        table = unAAMap wrapped

        -- In "all" mode a single bad residue discards the whole map.
        kept
          | allFlag && contaminated = Map.empty
          | otherwise               = table

        contaminated = any occurs badChars

        occurs bad = Map.member bad table
                  || any (Map.member bad) (Map.elems table)

        -- Used on both the outer map and each inner map.
        dropBad :: Map.Map AA v -> Map.Map AA v
        dropBad = Map.filterWithKey (\k _ -> k `notElem` badChars)
-- | Collect weighted residue pairs between every two distinct clusters of
-- one column.
--
-- Each inner sequence holds the @(residue, weight)@ entries of one cluster.
-- Every entry @(a, b)@ of a cluster is paired with every entry @(c, d)@ of
-- each later cluster, giving @(a, (c, 1 / (b * d)))@. Each such pair is also
-- emitted with its residues swapped, unless both residues are equal.
-- Entries within the same cluster are never paired. The pairs found for a
-- cluster are prepended to the accumulator, which is returned once no
-- clusters remain.
--
-- Written as a @case@ on 'Seq.viewl' so GHC's coverage checker can see that
-- every input is handled. Only the constraints actually used are required:
-- 'Eq' for the residue comparison and 'Fractional' for the weight.
collectPairs
    :: (Eq a, Fractional b)
    => Seq.Seq (Seq.Seq (a, b))
    -> Seq.Seq (a, (a, b))
    -> Seq.Seq (a, (a, b))
collectPairs clusters !acc = case Seq.viewl clusters of
    Seq.EmptyL          -> acc
    cluster Seq.:< rest -> collectPairs rest (againstLater cluster rest Seq.>< acc)
  where
    -- Pairs of one cluster against all later clusters: forward orientation
    -- first, then the swapped orientation (identical residues kept once).
    againstLater cluster rest = forward Seq.>< backward
      where
        forward  = F.asum (fmap (pairUp cluster) rest)
        backward = Seq.filter (\(r, (s, _)) -> r /= s) (fmap swapPair forward)

    swapPair (!a, (!b, !c)) = (b, (a, c))

    -- Cartesian product of two clusters, weighted by the inverse product of
    -- the two entries' weights.
    pairUp xs ys = (\(!a, !b) (!c, !d) -> (a, (c, 1 / (b * d)))) <$> xs <*> ys
-- | Sum the weighted cross-cluster residue pairs of one column into a
-- nested residue-by-residue map.
--
-- The pairs come from 'collectPairs' with an empty accumulator. Values for
-- a repeated (outer residue, inner residue) combination are added together.
toAAMap :: Seq.Seq (Seq.Seq (AA, Frequency)) -> AAMap
toAAMap clusters = AAMap (Map.fromListWith (Map.unionWith (+)) entries)
  where
    entries =
        [ (outer, Map.singleton inner weight)
        | (outer, (inner, weight)) <- F.toList (collectPairs clusters Seq.empty)
        ]
-- | Combine the per-cluster position tables of one block into a single
-- residue-pair map.
--
-- The clusters' tables are grouped by column position, keeping every
-- cluster's residue sequence. Each column is turned into weighted residue
-- pairs with 'toAAMap' and filtered with 'removeChar'. The per-column maps
-- are then combined with 'mconcat'.
getBlockMap :: Bool -> Maybe [AA] -> [ClusterFrequencyMap] -> BlockMap
getBlockMap allFlag badChars clusterMaps =
    BlockMap (mconcat (Map.elems (Map.map scoreColumn columns)))
  where
    scoreColumn = removeChar allFlag badChars . toAAMap

    -- Position -> sequence of clusters' residue lists at that position.
    columns = Map.unionsWith (Seq.><)
        [ Map.map Seq.singleton (unClusterFrequencyMap cfm) | cfm <- clusterMaps ]
-- | Merge the residue-pair maps of all blocks into one 'FrequencyMap',
-- using the 'Monoid' instance of the wrapped maps.
joinBlockMaps :: [BlockMap] -> FrequencyMap
joinBlockMaps blockMaps = FrequencyMap (mconcat [ unBlockMap b | b <- blockMaps ])
-- | Turn accumulated residue-pair counts into BLOSUM-style log-odds scores.
--
-- For residues @x@ and @y@:
--
-- * the observed frequency is the stored count divided by the total pair
--   count;
-- * the expected frequency is @p x * p y@ when @x == y@, else
--   @2 * p x * p y@, where @p@ is the marginal residue frequency;
-- * the score is @round (2 * logBase 2 (observed / expected))@.
--
-- Only residue combinations present in the input receive a score.
getBlosum :: FrequencyMap -> Blosum
getBlosum (FrequencyMap (AAMap frequencyMap)) =
    Blosum (Map.mapWithKey scoreRow frequencyMap)
  where
    scoreRow x = Map.mapWithKey (\y _ -> score x y)

    -- Log-odds in half-bit units, rounded to an integer.
    score x y = BlosumVal (round (2 * logBase 2 (observed x y / expected x y)))

    -- Off-diagonal pairs can occur in either order, hence the factor of two.
    expected x y
        | x == y    = marginal x * marginal y
        | otherwise = 2 * marginal x * marginal y

    -- Residue x's self-pair frequency plus half of its pair frequencies with
    -- every other residue.
    marginal x =
        observed x x + sum [ observed x y | y <- Map.keys frequencyMap, y /= x ] / 2

    -- Observed pair frequency as a plain number; missing entries count as 0.
    observed x y =
        case countOf x y / totalPairs of
            Frequency f -> f

    countOf x y =
        fromMaybe 0 (Map.lookup y (fromMaybe Map.empty (Map.lookup x frequencyMap)))

    -- Assumes off-diagonal pairs are stored in both orientations, as
    -- collectPairs produces them. Adding the diagonal once more and halving
    -- then counts every unordered pair exactly once.
    totalPairs = (grandTotal + diagonalTotal) / 2

    grandTotal = sumMap (Map.map sumMap frequencyMap)

    diagonalTotal =
        sumMap (Map.mapWithKey
                   (\k1 -> sumMap . Map.filterWithKey (\k2 _ -> k1 == k2))
                   frequencyMap)