{-# LANGUAGE LambdaCase #-}
module Overload.Diff where

import Data.Either (partitionEithers)

import Overload.TypeTree

-- | One step of a path through a 'TypeTree': which child of an 'App' node
-- to descend into.
data DiffStep
    = GoLeft   -- ^ Descend into the first child of an 'App'.
    | GoRight  -- ^ Descend into the second child of an 'App'.
    deriving (Show, Read, Ord, Eq)

-- | A path from some node of a 'TypeTree' down to one of its descendants.
-- The empty path denotes the starting node itself.
type Diff = [DiffStep]

-- | Paths from the root of a tree to every 'Concrete' leaf, listing the
-- left subtree's paths before the right subtree's. 'Var' leaves contribute
-- no paths; a 'Concrete' root yields the single empty path.
wholeTreeDiff :: TypeTree a -> [Diff]
wholeTreeDiff tree = case tree of
    Var _ -> []
    Concrete _ -> [[]]
    App l r ->
        [ step : path
        | (step, child) <- [(GoLeft, l), (GoRight, r)]
        , path <- wholeTreeDiff child
        ]

-- | Positions at which two type trees disagree, as paths from the root of
-- the first tree.
--
-- * A 'Var' on either side yields no positions.
-- * Two 'Concrete' leaves disagree at the root iff their payloads differ.
-- * A 'Concrete' in the first tree against an 'App' disagrees at the root.
-- * Two 'App' nodes are compared child by child.
-- * An 'App' in the first tree against a 'Concrete' yields the root itself
--   plus the path to every 'Concrete' leaf beneath that 'App'.
--
-- An empty result means no disagreeing position was found.
diff :: TypeTree a -> TypeTree b -> [Diff]
diff lhs rhs = case (lhs, rhs) of
    (Var _, _) -> []
    (_, Var _) -> []
    (Concrete x, Concrete y) -> [[] | x /= y]
    (Concrete _, App _ _) -> [[]]
    (App l1 r1, App l2 r2) -> branches (diff l1 l2) (diff r1 r2)
    -- The empty path records that the App/Concrete shape mismatch is itself
    -- a difference, even when the App subtree has no Concrete leaves.
    (App _ _, Concrete _) -> [] : wholeTreeDiff lhs
  where
    branches ls rs = map (GoLeft :) ls ++ map (GoRight :) rs

-- | Prune a tree down to the single node reached by following a path.
--
-- Siblings that are not on the path become @Var Nothing@. At the end of the
-- path a 'Var' keeps its payload wrapped in 'Just', a 'Concrete' is kept as
-- is, and an 'App' is kept with both children replaced by @Var Nothing@.
-- Calls 'error' if the path tries to descend through a leaf.
treeFromDiff :: Diff -> TypeTree a -> TypeTree (Maybe a)
treeFromDiff path tree = case (path, tree) of
    ([], Var n) -> Var (Just n)
    ([], Concrete c) -> Concrete c
    ([], App _ _) -> App hole hole
    (GoLeft : rest, App l _) -> App (treeFromDiff rest l) hole
    (GoRight : rest, App _ r) -> App hole (treeFromDiff rest r)
    _ -> error "Invalid diff for type tree"
  where
    hole = Var Nothing

-- | Split off the first step of a path: a path starting with 'GoLeft' becomes
-- 'Left' of the remaining steps, and one starting with 'GoRight' becomes
-- 'Right'. Calls 'error' on the empty path.
diffToEither :: Diff -> Either Diff Diff
diffToEither steps = case steps of
    GoLeft : rest -> Left rest
    GoRight : rest -> Right rest
    [] -> error "Can't convert an empty diff into an Either"

-- | Prune a tree down to the union of the nodes reached by several paths.
--
-- A subtree reached by no path becomes @Var Nothing@. A leaf where every
-- path ends is kept (a 'Var' payload is wrapped in 'Just'); a path that would
-- continue past a leaf is an 'error'. At an 'App', paths that end there are
-- dropped (the 'App' node itself is kept), and the rest are routed to the
-- left or right child by their first step.
treeFromDiffs :: [Diff] -> TypeTree a -> TypeTree (Maybe a)
treeFromDiffs [] _ = Var Nothing
treeFromDiffs ds tree = case tree of
    Var n
        | allEndHere -> Var (Just n)
        | otherwise -> leafError
    Concrete c
        | allEndHere -> Concrete c
        | otherwise -> leafError
    App l r -> App (treeFromDiffs toLeft l) (treeFromDiffs toRight r)
  where
    allEndHere = all null ds
    leafError = error "Diff is trying to go through a leaf"
    (toLeft, toRight) =
        partitionEithers [diffToEither d | d <- ds, not (null d)]

-- | Pruned copies of @t@ built from its differences with every tree in @ts@.
--
-- 'diff' is computed against each tree in @ts@. For every way of picking one
-- disagreeing position per tree (the cartesian product of those lists),
-- 'treeFromDiffs' builds the copy of @t@ that keeps exactly the picked
-- positions. The result length is the product of the per-tree counts, so it
-- is empty as soon as some tree in @ts@ has no disagreeing position with @t@.
-- Different picks may produce equal trees; duplicates are not removed.
--
-- With no other trees, this instead yields one pruned copy of @t@ per
-- 'Concrete' leaf of @t@ (see 'wholeTreeDiff' and 'treeFromDiff').
deciders :: TypeTree a -> [TypeTree b] -> [TypeTree (Maybe a)]
deciders t [] = map (`treeFromDiff` t) (wholeTreeDiff t)
-- 'traverse' in the list applicative enumerates every combination of one
-- diff per tree; it is the idiomatic form of @sequence (fmap (diff t) ts)@
-- and produces the same combinations in the same order.
deciders t ts = map (`treeFromDiffs` t) (traverse (diff t) ts)