module HarmTrace.Matching.Alignment ( alignChordLab, pPrintV, getAlignDist
, getHAnDist, alignHAnChord
) where
import HarmTrace.Base.MusicRep
import HarmTrace.Matching.SimpleChord
import HarmTrace.Matching.HChord
import HarmTrace.Matching.Sim
import HarmTrace.HAnTree.HAn
import HarmTrace.HAnTree.Tree
import Prelude hiding (map, length, head, last, mapM_, max)
import Data.Vector hiding ((!), (++))
import qualified Data.List as L
-- | Alignment score of two chord-label sequences, each interpreted in its own
-- key. Only the score component of 'alignChordLab' is returned; the matched
-- chords and the score table are discarded.
getAlignDist :: Key -> Key -> [ChordLabel] -> [ChordLabel] -> Float
getAlignDist ka kb ta tb = score
  where
    (_, score, _) = alignChordLab ka kb ta tb
-- | Align two chord-label sequences, each converted to 'SimChord's relative to
-- its own key (@ka@ for @ta@, @kb@ for @tb@).
--
-- Returns:
--
--   * the 'SimChord's of the first sequence that take part in the alignment,
--   * the squared alignment weight divided by @maxSim ta' * maxSim tb'@,
--   * the full score table produced by 'align'.
alignChordLab :: Key -> Key -> [ChordLabel] -> [ChordLabel]
              -> ([SimChord], Float, Vector (Vector Int))
alignChordLab ka kb ta tb = (fst $ matchToSeq match ta' tb', dis, tab) where
  -- The insertion/deletion cost is added to the score for every gap, so it
  -- must be negative to act as a penalty. A positive value would reward gaps.
  (match, weight, tab) = align (-2) ta' tb'
  -- Squared weight normalised by the product of both sequences' 'maxSim'.
  -- NOTE(review): presumably 'maxSim' is the best attainable score for a
  -- sequence; confirm in HarmTrace.Matching.Sim.
  dis = fromIntegral (weight * weight)
      / fromIntegral (maxSim ta' * maxSim tb')
  ta' = L.concatMap (toSimChords . toChordDegree ka) ta
  tb' = L.concatMap (toSimChords . toChordDegree kb) tb
-- | Alignment score of two harmony-analysis trees. Only the score component of
-- 'alignHAnChord' is returned.
getHAnDist :: Tree HAn -> Tree HAn -> Float
getHAnDist ta tb = case alignHAnChord ta tb of
  (_, score, _) -> score
-- | Align the chord sequences extracted (via 'toHChords') from two
-- harmony-analysis trees.
--
-- Returns:
--
--   * the 'HChord's of the first tree that take part in the alignment,
--   * the squared alignment weight divided by @maxSim ta' * maxSim tb'@,
--   * the full score table produced by 'align'.
alignHAnChord :: Tree HAn -> Tree HAn -> ([HChord], Float, Vector (Vector Int))
alignHAnChord ta tb =
  (fst $ matchToSeq match ta' tb', dis, tab) where
  -- The insertion/deletion cost is added to the score for every gap, so it
  -- must be negative to act as a penalty. A positive value would reward gaps.
  (match, weight, tab) = align (-2) ta' tb'
  -- Squared weight normalised by the product of both sequences' 'maxSim'.
  dis = fromIntegral (weight * weight)
      / fromIntegral (maxSim ta' * maxSim tb')
  ta' = toHChords ta
  tb' = toHChords tb
-- | Align two sequences using insertion/deletion cost @inDel@.
--
-- Returns a triple:
--
--   * the matched index pairs traced back through the score table,
--   * the score in the table's bottom-right cell,
--   * the table itself.
--
-- If either sequence is empty, the result is an empty match with weight 0.
align :: Sim a => Int -> [a] -> [a] -> ([(Int,Int)], Int, Vector (Vector Int))
align _ _ [] = ([], 0, empty)
align _ [] _ = ([], 0, empty)
align inDel xs ys = (pairs, getDownRight table, table)
  where
    table = wbMatchF inDel xs ys
    pairs = toList (collectMatch table)
-- | Fill the dynamic-programming score table for two sequences.
--
-- Each cell @(i, j)@ is the maximum of three candidates and zero:
--
--   * the diagonal predecessor plus @sim@ of elements @i@ and @j@,
--   * the cell above plus @inDel@,
--   * the cell to the left plus @inDel@.
--
-- Scores are therefore never negative. @inDel@ is added once per gap, so it
-- should be negative to act as a penalty. Returns 'empty' when either input
-- is empty.
wbMatchF :: Sim a => Int -> [a] -> [a] -> Vector (Vector Int)
wbMatchF _ _ [] = empty
wbMatchF _ [] _ = empty
wbMatchF inDel a' b' = m where
  a = fromList a'
  b = fromList b'
  match, fill :: Int -> Int -> Int
  match i j = sim (a ! i) (b ! j)
  -- The corner, the first row and the first column each have at most one
  -- in-table predecessor.
  fill 0 0 = max (match 0 0) 0
  fill 0 j = max0 (((m ! 0    ) ! (j-1)) + inDel) (match 0 j)
  fill i 0 = max0 (((m ! (i-1)) ! 0    ) + inDel) (match i 0)
  -- Interior cells: gap from above, (mis)match on the diagonal, gap from left.
  fill i j = max3 (((m ! (i-1)) !  j   ) + inDel)
                  (((m ! (i-1)) ! (j-1)) + match i j)
                  (((m !  i   ) ! (j-1)) + inDel)
  -- The table is defined in terms of itself. This relies on the elements of
  -- the boxed 'Vector' being evaluated lazily, on demand.
  m = generate (length a) (generate (length b) . fill)
-- | Trace back through a filled score table, starting at the bottom-right cell,
-- and return the matched @(row, column)@ index pairs.
--
-- 'collect' prepends pairs while walking back towards @(0,0)@, so the result
-- is in increasing index order. An empty table (or one with empty rows) yields
-- 'empty' instead of failing on 'head'.
collectMatch :: Vector (Vector Int) -> Vector (Int,Int)
collectMatch a
  | length a == 0 || length (head a) == 0 = empty
  | otherwise =
      fromList $ collect a (length a - 1, length (head a) - 1) []
-- | Walk back through score table @a@ from cell @c@ towards @(0,0)@.
--
-- A cell is prepended to the accumulator @m@ when its score exceeds the score
-- of the neighbour the walk steps back to. Indexing uses the module's
-- unchecked '(!)'; every index stays within the table because the row or
-- column is only decremented when it is positive.
collect :: (Ord b, Num b) => Vector (Vector b) -> (Int, Int) -> [(Int, Int)]
        -> [(Int, Int)]
-- Top-left corner: recorded iff its score is positive.
collect a c@(0,0) m = if (a ! 0) ! 0 > 0 then c : m else m
-- First column: compare with the cell above. Stop as soon as a cell is recorded.
collect a c@(i,0) m = if (a ! i) ! 0 > (a ! (i-1)) ! 0
                      then c : m else collect a (i-1,0) m
-- First row: compare with the cell to the left. Stop as soon as a cell is recorded.
collect a c@(0,j) m = if (a ! 0) ! j > (a ! 0) ! (j-1)
                      then c : m else collect a (0,j-1) m
-- Interior: step to the highest-scoring of the up, diagonal and left
-- neighbours, recording this cell if it scores higher than that neighbour.
collect a c@(i,j) m
  | (a ! i) ! j > snd o = collect a (fst o) (c : m)
  | otherwise           = collect a (fst o) m
  where
    o = realMax3 ((i-1,j)  , (a ! (i-1)) ! j    )
                 ((i-1,j-1), (a ! (i-1)) ! (j-1))
                 ((i,j-1)  , (a !  i   ) ! (j-1))
-- | Pick the labelled value with the highest weight among the up (@w@),
-- diagonal (@nw@) and left (@n@) candidates.
--
-- Ties are broken in favour of @n@, then @w@, then @nw@.
realMax3 :: (Ord a) => (t, a) -> (t, a) -> (t, a) -> (t, a)
realMax3 w nw n
  | snd nw > snd best = nw
  | otherwise         = best
  where
    best
      | snd w > snd n = w
      | otherwise     = n
-- | Select from each original sequence the elements whose positions appear on
-- that sequence's side of the matched index pairs.
--
-- Original element order is preserved.
matchToSeq :: [(Int,Int)] -> [a] -> [a] -> ([a],[a])
matchToSeq mat aOrg bOrg = (keep aIdx aOrg, keep bIdx bOrg)
  where
    (aIdx, bIdx) = L.unzip mat
    keep idxs xs = [ x | (x, k) <- L.zip xs [0 ..], k `L.elem` idxs ]
-- | Vector indexing without a bounds check (delegates to 'unsafeIndex').
--
-- Callers must keep the index in range; an out-of-range index is not reported
-- as an error.
(!) :: Vector a -> Int -> a
v ! i = unsafeIndex v i
-- | Largest of three values and zero.
max3 :: (Ord a, Num a) => a -> a -> a -> a
max3 x y z = x `max` max0 y z

-- | Largest of two values and zero.
max0 :: (Ord a, Num a) => a -> a -> a
max0 x y = x `max` (y `max` 0)

-- | Module-local replacement for the Prelude's @max@, which is hidden by this
-- module's imports. Returns the second argument when the two compare equal.
max :: (Ord a, Num a) => a -> a -> a
max x y
  | x <= y    = y
  | otherwise = x
-- | Bottom-right element of a table: the last element of the last row.
--
-- Partial: fails if the table or its last row is empty.
getDownRight :: Vector (Vector a) -> a
getDownRight = last . last
-- | Print a table one row per line, with each element 'show'n and followed by
-- a single space.
pPrintV :: Show a => Vector (Vector a) -> IO ()
pPrintV = mapM_ printRow where
  printRow :: Show a => Vector a -> IO ()
  printRow row = mapM_ (putStr . (++ " ") . show) row >> putChar '\n'