module Biobase.Secondary.Diagrams where
import           Control.Applicative
import           Control.Arrow
import           Control.Lens
import           Data.Aeson
import           Data.Binary
import           Data.List ((\\))
import           Data.List (sort,groupBy,sortBy,intersperse)
import           Data.List.Split (splitOn)
import           Data.Serialize
import           Data.Tuple.Select (sel1,sel2)
import           Data.Tuple (swap)
import           Data.Vector.Binary
import           Data.Vector.Serialize
import           GHC.Generics
import qualified Data.Vector.Unboxed as VU
import           Text.Printf
import           Control.DeepSeq
import           Biobase.Primary.Nuc
import           Biobase.Secondary.Basepair
-- | Secondary structure with a single 'Int' entry per position. The
-- @(Int,[PairIdx])@ instance of 'MkD1Secondary' stores the index of the
-- pairing partner at each position and @-1@ where a position is unpaired.
newtype D1Secondary = D1S {unD1S :: VU.Vector Int}
  deriving (Read,Show,Eq,Generic,NFData)
-- Serialisation instances. No methods are given, so each class falls back to
-- its default implementations (presumably the 'Generic'-based ones, given the
-- derived 'Generic' instance -- confirm against the class definitions).
instance Binary    D1Secondary
instance Serialize D1Secondary
instance FromJSON  D1Secondary
instance ToJSON    D1Secondary
-- | Secondary structure with two annotated partner slots per position. Each
-- slot is a triple of partner index, 'Edge' and 'CTisomerism'. The
-- @(Int,[ExtPairIdx])@ instance of 'MkD2Secondary' initialises every slot to
-- @(-1,W,Cis)@ and treats a partner index of @-1@ as "slot not used".
newtype D2Secondary = D2S {unD2S :: VU.Vector ( (Int,Edge,CTisomerism), (Int,Edge,CTisomerism) )}
  deriving (Read,Show,Eq,Generic)
-- Serialisation instances. No methods are given, so each class falls back to
-- its default implementations (presumably the 'Generic'-based ones, given the
-- derived 'Generic' instance -- confirm against the class definitions).
instance Binary    D2Secondary
instance Serialize D2Secondary
instance FromJSON  D2Secondary
instance ToJSON    D2Secondary
-- | Types that can be converted to and from a 'D1Secondary'.
class MkD1Secondary a where
  -- | Build a 'D1Secondary' from a value of type @a@.
  mkD1S :: a -> D1Secondary
  -- | Turn a 'D1Secondary' back into a value of type @a@.
  fromD1S :: D1Secondary -> a
-- | Types that can be converted to and from a 'D2Secondary'.
class MkD2Secondary a where
  -- | Build a 'D2Secondary' from a value of type @a@.
  mkD2S :: a -> D2Secondary
  -- | Turn a 'D2Secondary' back into a value of type @a@.
  fromD2S :: D2Secondary -> a
-- | Tree representation of a secondary structure, carrying a label of type
-- @a@ at every node.
--
-- As built by 'd1sTree' and 'd2sTree': 'SSExtern' is the root and stores the
-- length of the structure plus the outermost pairs as children; 'SSTree'
-- stores one pair index of type @idx@ plus the pairs nested inside it.
data SSTree idx a = SSTree   idx a [SSTree idx a]
                  | SSExtern Int a [SSTree idx a]
  deriving (Read,Show,Eq,Generic)
-- | Convert a 'D1Secondary' into its tree form. The pairs are sorted, then
-- grouped recursively: a pair adopts every following pair whose left index is
-- smaller than its own right index. The root is an 'SSExtern' holding the
-- structure length.
d1sTree :: D1Secondary -> SSTree PairIdx ()
d1sTree s = SSExtern len () (forest (sort ps))
  where
    (len,ps) = fromD1S s
    -- 'groupBy' compares every candidate against the first pair of the
    -- current group, i.e. against the enclosing pair.
    forest xs = map subtree (groupBy encloses xs)
    encloses outer inner = snd outer > fst inner
    -- 'groupBy' never yields an empty group, so one equation is enough.
    subtree (ij:nested) = SSTree ij () (forest nested)
-- | Convert a 'D2Secondary' into its tree form. The extended pairs are
-- ordered with 'd2Compare' and nested recursively with 'd2Grouping'. The root
-- is an 'SSExtern' holding the structure length.
d2sTree :: D2Secondary -> SSTree ExtPairIdx ()
d2sTree s = SSExtern len () (forest (sortBy d2Compare ps))
  where
    (len,ps) = fromD2S s
    -- 'groupBy' compares every candidate against the first pair of the
    -- current group, i.e. against the enclosing pair.
    forest xs = map subtree (groupBy d2Grouping xs)
    -- 'groupBy' never yields an empty group, so one equation is enough.
    subtree (ij:nested) = SSTree ij () (forest nested)
-- | Ordering on annotated pairs used by 'd2sTree'; only the index pair is
-- inspected. With equal left indices the pair with the larger right index
-- comes first; with equal right indices the left indices decide; otherwise
-- the index pairs are compared lexicographically.
d2Compare :: (Ord a, Ord b) => ((a,b),c) -> ((a,b),d) -> Ordering
d2Compare ((i,j),_) ((k,l),_) = case (i == k, j == l) of
  (True,_) -> compare l j
  (_,True) -> compare i k
  _        -> compare (i,j) (k,l)
-- | Grouping predicate used by 'd2sTree': holds when the second index pair
-- lies within the first one (left index not smaller, right index not larger).
-- Annotations are ignored.
d2Grouping :: (Ord a, Ord b) => ((a,b),c) -> ((a,b),d) -> Bool
d2Grouping (outer,_) (inner,_) = fst outer <= fst inner && snd outer >= snd inner
-- | Conversion between the two diagram types. Both directions delegate to
-- the 'MkD2Secondary' instance for 'D1Secondary'.
instance MkD1Secondary D2Secondary where
  mkD1S extended = fromD2S extended
  fromD1S plain = mkD2S plain
-- | Structure length together with a list of index pairs.
instance MkD1Secondary (Int,[PairIdx]) where
  -- Start with every position unpaired (@-1@) and record each pair in both
  -- directions, so that both partners point at each other.
  mkD1S (len,ps) = D1S (unpaired VU.// bothDirections)
    where
      unpaired       = VU.replicate len (-1)
      bothDirections = [ entry | ij <- ps, entry <- [ij, swap ij] ]
  -- Keep each pair once, namely from the position with the smaller index.
  fromD1S (D1S s) =
    ( VU.length s
    , [ (i,j) | (i,j) <- zip [0..] (VU.toList s), i<j && j>=0 ]
    )
-- | Conversion between a plain and an extended diagram.
instance MkD2Secondary D1Secondary where
  -- The partner goes into the first slot with edge 'W' and 'Cis'; the second
  -- slot is left at @(-1,W,Cis)@.
  mkD2S (D1S partners) = D2S (VU.map widen partners)
    where widen k = ((k,W,Cis),(-1,W,Cis))
  -- Only the partner index of the first slot is kept.
  fromD2S = D1S . VU.map (\((k,_,_),_) -> k) . unD2S
-- | Structure length together with a list of annotated index pairs.
instance MkD2Secondary (Int,[ExtPairIdx]) where
  -- Every pair is written to both of its positions. A new entry goes into
  -- the first slot while that slot still holds partner @-1@, and into the
  -- second slot otherwise.
  mkD2S (len,ps) = D2S $ VU.accum place blank updates
    where
      unpaired = (-1,W,Cis)
      blank    = VU.replicate len (unpaired,unpaired)
      updates  = concat [ [ (i,(j,e1,ct)), (j,(i,e2,ct)) ]
                        | ((i,j),(ct,e1,e2)) <- ps ]
      place (slotA,slotB) new
        | sel1 slotA == -1 = (new,slotB)
        | otherwise        = (slotA,new)
  -- All first slots are listed before all second slots; each pair is kept
  -- once, from the position with the smaller index.
  fromD2S (D2S s) = (VU.length s, pairs)
    where
      (firsts,seconds) = unzip (VU.toList s)
      candidates       = zip [0..] firsts ++ zip [0..] seconds
      pairs = [ ((i,j),(ct,eI,partnerEdge j i))
              | (i,(j,eI,ct)) <- candidates
              , i<j && j>=0
              ]
      -- Edge stored at position p for the slot whose partner is q; the
      -- second slot is used whenever the first one does not point at q.
      partnerEdge p q
        | sel1 (sel1 slots) == q = sel2 (sel1 slots)
        | otherwise              = sel2 (sel2 slots)
        where slots = s VU.! p
-- | Bracket dictionary together with a dot-bracket string.
instance MkD1Secondary ([String],String) where
  -- Parsed with 'unsafeDotBracket2pairlist'; the structure length is the
  -- length of the string.
  mkD1S (dict,xs) =
    let pairs :: [(Int,Int)]
        pairs = unsafeDotBracket2pairlist dict xs
    in  mkD1S (length xs,pairs)
  -- Always rendered with round brackets; the returned dictionary is @["()"]@.
  fromD1S (D1S s) = (["()"], [ symbol k p | (k,p) <- zip [0..] (VU.toList s) ])
    where
      symbol k p
        | p == -1   = '.'
        | k > p     = ')'
        | otherwise = '('
-- | Same as the @([String],String)@ instance, with the dot-bracket characters
-- held in an unboxed vector; both directions go through the list-based
-- instance.
instance MkD1Secondary ([String],VU.Vector Char) where
  mkD1S (dict,symbols) = mkD1S (dict, VU.toList symbols)
  fromD1S s = (dict, VU.fromList symbols)
    where (dict,symbols) = fromD1S s
-- | Plain dot-bracket string, read with the fixed dictionary @["()"]@. The
-- dictionary returned by the underlying instance is dropped on the way back.
instance MkD1Secondary String where
  mkD1S xs = mkD1S (["()"] :: [String], xs)
  fromD1S s = snd (fromD1S s :: ([String],String))
-- | Dot-bracket characters in an unboxed vector, read with the fixed
-- dictionary @["()"]@. The dictionary returned by the underlying instance is
-- dropped on the way back.
instance MkD1Secondary (VU.Vector Char) where
  mkD1S xs = mkD1S (["()"] :: [String], xs)
  fromD1S s = snd (fromD1S s :: ([String],VU.Vector Char))
-- | 'True' when every character of the string is one of @(@, @)@ or @.@.
-- The empty string is accepted.
isCanonicalStructure :: String -> Bool
isCanonicalStructure str = all allowed str
  where allowed c = c `elem` "()."
-- | 'True' when every character of the string is one of
-- @(@, @)@, @.@, @<@, @>@, @{@, @}@ or @|@. The empty string is accepted.
isConstraintStructure :: String -> Bool
isConstraintStructure str = not (any (`notElem` "().<>{}|") str)
-- | View a string as the list of its @&@-separated parts, and join such a
-- list back together with @&@ between the parts.
structures :: Iso' String [String]
structures = iso toParts fromParts
  where
    toParts str     = splitOn "&" str
    fromParts parts = concat (intersperse "&" parts)
-- | Matches strings that consist of exactly one part according to
-- 'structures'. Any other string is handed back unchanged as a failed match.
-- Building is the identity.
foldStructure :: Prism' String String
foldStructure = prism id single
  where
    single s
      | [t] <- s^.structures = Right t
      | otherwise            = Left s
-- | Matches strings that consist of exactly two parts according to
-- 'structures' and yields them as a pair. Any other string is handed back
-- unchanged as a failed match. Building joins the two parts with @&@.
cofoldStructure :: Prism' String (String,String)
cofoldStructure = prism glue separate
  where
    glue (l,r) = l ++ "&" ++ r
    separate s
      | [l,r] <- s^.structures = Right (l,r)
      | otherwise              = Left s
-- | Extract the sorted list of index pairs from a dot-bracket string.
--
-- Each dictionary entry must be a two-character string of opening and closing
-- bracket; the string is scanned once per entry and every other character is
-- treated as @.@. This function is partial: a closing bracket without a
-- matching opening bracket calls 'error', and a dictionary entry of any other
-- length is a pattern-match failure. Opening brackets that are never closed
-- are silently dropped.
unsafeDotBracket2pairlist :: [String] -> String -> [(Int,Int)]
unsafeDotBracket2pairlist dict xs = sort (concatMap pairsFor dict)
  where
    pairsFor [open,close] = walk 0 [] masked
      where
        -- Hide every symbol that does not belong to this bracket type.
        masked = [ if c == open || c == close then c else '.' | c <- xs ]
        -- Current position, stack of open positions, remaining input.
        walk :: Int -> [Int] -> String -> [(Int,Int)]
        walk pos stack input = case input of
          [] -> []
          (c:rest)
            | c == '.'  -> walk (pos+1) stack rest
            | c == open -> walk (pos+1) (pos:stack) rest
            | c == close, (top:below) <- stack
                        -> (top,pos) : walk (pos+1) below rest
            | otherwise -> error $ show (pos,stack,input)
-- | Extract the sorted list of index pairs from a dot-bracket string,
-- reporting malformed input as a 'Left' error message.
--
-- Each dictionary entry must be a two-character string of opening and closing
-- bracket; the string is scanned once per entry and every other character is
-- treated as @.@. A 'Left' is returned for a closing bracket without a
-- matching opening bracket, for opening brackets that are never closed, and
-- for a dictionary entry that is not exactly two characters long.
--
-- Errors are built with 'Left' directly. Using 'fail' here would not yield a
-- 'Left': @Either e@ has no 'MonadFail' instance, and on older versions of
-- base the default 'fail' called 'error'.
dotBracket2pairlist :: [String] -> String -> Either String ( [(Int,Int)] )
dotBracket2pairlist dict str = fmap (sort . concat) . mapM (f str) $ dict where
  f ys [l,r] = g 0 [] . map (\x -> if x `elem` [l,r] then x else '.') $ ys where
    -- Current position, stack of open positions, remaining input.
    g :: Int -> [Int] -> String -> Either String ( [(Int,Int)] )
    g _ [] [] = pure []
    g k st ('.':xs) = g (k+1) st xs
    g k st (x:xs) | l==x = g (k+1) (k:st) xs
    g k (s:st) (x:xs) | r==x = ((s,k):) <$> g (k+1) st xs
    g k [] xs = Left $ printf "too many closing brackets at position %d: '%s' (dot-bracket: %s)" k xs str
    g k st [] = Left $ printf "too many opening brackets, opening bracket(s) at: %s (dot-bracket: %s)" (show $ reverse st) str
    g a b c   = Left $ printf "unspecified error: %s (dot-bracket: %s)" (show (a,b,c)) str
  -- Three or more characters cannot describe a single bracket pair.
  f xs lr@(_:_:_:_) = Left $ printf "unsound dictionary: %s (dot-bracket: %s)" lr str
  -- Remaining case: an entry with fewer than two characters.
  f xs lr     = Left $ printf "unspecified error: dict: %s, input: %s (dot-bracket: %s)" lr xs str
-- | Pair-set distance between two dot-bracket strings using round brackets.
--
-- The result is the second string together with a count of pairs. With the
-- first flag set, pairs of @s@ that are missing from @t@ are counted; with
-- the second flag set, pairs of @t@ that are missing from @s@ are counted.
-- A string rejected by 'dotBracket2pairlist' results in a call to 'error'.
viennaStringDistance :: Bool -> Bool -> String -> String -> (String,Int)
viennaStringDistance sPairs tPairs s t = (t, length (onlyInS ++ onlyInT))
  where
    parse str = either error id (dotBracket2pairlist ["()"] str)
    pairsS    = parse s
    pairsT    = parse t
    onlyInS
      | sPairs    = pairsS \\ pairsT
      | otherwise = []
    onlyInT
      | tPairs    = pairsT \\ pairsS
      | otherwise = []
-- | Distance between two 'D1Secondary' structures.
--
-- The shorter vector is padded with @-2@ to the length of the longer one.
-- Each position then scores 0 when both entries are equal or both negative,
-- 2 when both are non-negative but differ, and 1 otherwise. The result is the
-- sum of the scores, halved with integer division.
d1Distance :: D1Secondary -> D1Secondary -> Int
d1Distance (D1S x) (D1S y) = VU.sum (VU.zipWith chk paddedX paddedY) `div` 2
  where lenX    = VU.length x
        lenY    = VU.length y
        -- 'VU.replicate' with a non-positive count adds nothing, so only the
        -- shorter vector grows.
        paddedX = x VU.++ VU.replicate (lenY - lenX) (-2)
        paddedY = y VU.++ VU.replicate (lenX - lenY) (-2)
        chk i j
          | i == j || (i < 0 && j < 0) = 0
          | i >= 0 && j >= 0           = 2
          | otherwise                  = 1
        {-# Inline chk #-}
{-# NoInline d1Distance #-}