{-# LANGUAGE TypeSynonymInstances, FlexibleInstances #-}
module CSPM.Evaluator.PatBind where

import CSPM.DataStructures.Literals
import CSPM.DataStructures.Names
import CSPM.DataStructures.Syntax
import CSPM.Evaluator.Values
import CSPM.Evaluator.ValueSet
import Util.Annotated

-- | Things (patterns) that can be matched against an evaluated 'Value'.
class Bindable a where
    -- | @bind pat value@ returns a pair: whether @pat@ matches @value@, and
    -- the list of names bound by the match together with their values.
    bind :: a -> Value -> (Bool, [(Name, Value)])

instance Bindable a => Bindable (Annotated b a) where
    -- The annotation fields play no part in matching: discard them and
    -- delegate to the wrapped value's own instance.
    bind annotated value =
        case annotated of
            An _ _ wrapped -> bind wrapped value

-- | Matching of (desugared) patterns against evaluated values.
--
-- When a composite pattern fails to match, some clauses still return the
-- bindings produced by the sub-patterns that did match, so callers should
-- inspect the 'Bool' before using the bindings.
instance Bindable (Pat Name) where
    -- A compiled list pattern (PCompList) is decomposed into three parts:
    -- a beginning (a fixed run of element patterns), an optional middle
    -- (either PWildCard or PVar) and an end (another fixed run of element
    -- patterns). Both the beginning and the end may be empty.

    -- No middle: the list must have exactly one element per pattern.
    bind (PCompList ps Nothing _) (VList xs) | length ps == length xs =
        bindAll ps xs
    -- With a middle. By desugaring, the middle is not a PConcat or a PList.
    bind (PCompList starts (Just (middle, ends)) _) (VList xs) =
        -- Only match if the list has enough items for the fixed beginning
        -- and end patterns.
        if not (atLeastLength (length starts + length ends) xs) then
            (False, [])
        else
            let
                (b1, nvs1) = bindAll starts xsStart
                (b2, nvs2) = bindAll ends xsEnd
                (b3, nvs3) = bind middle (VList xsMiddle)
            in (b1 && b2 && b3, nvs1++nvs2++nvs3)
        where
            -- Lower-bound length check that inspects at most n cells of the
            -- list rather than computing its full length.
            atLeastLength :: Int -> [a] -> Bool
            atLeastLength 0 _ = True
            atLeastLength _ [] = False
            atLeastLength n (_:ys) = atLeastLength (n-1) ys
            (xsStart, rest) = splitAt (length starts) xs
            -- With no end patterns the whole remainder is the middle; this
            -- branch avoids computing `length rest` at all.
            (xsMiddle, xsEnd) =
                if null ends then (rest, [])
                else splitAt (length rest - length ends) rest
    bind (PCompDot ps _) (VDot vs) =
        let
            -- Matches a compiled dot pattern, given a list of patterns for
            -- the fields and the values that each field takes.
            matchCompDot :: [Pat Name] -> [Value] -> (Bool, [(Name, Value)])
            matchCompDot [] [] = (True, [])
            matchCompDot (PVar n:pats) (VDot (VDataType n':vfs):vals)
                | isNameDataConstructor n =
                    -- The next value is itself a dotted value headed by a
                    -- data constructor, i.e. we are matching within a
                    -- subfield of the current field. Check the constructor,
                    -- then match the remaining patterns against the
                    -- subfield's values followed by the remaining values.
                    if n /= n' then (False, [])
                    else matchCompDot pats (vfs++vals)
            matchCompDot (PVar n:pats) (VDot (VChannel n':vfs):vals)
                | isNameDataConstructor n =
                    -- As above, but the subfield is headed by a channel.
                    if n /= n' then (False, [])
                    else matchCompDot pats (vfs++vals)
            matchCompDot [p] [v] = bind p v
            -- A single remaining pattern absorbs all remaining values.
            matchCompDot [p] vals = bind p (VDot vals)
            matchCompDot (p:pats) (v:vals) =
                let
                    (b1, nvs1) = bind p v
                    (b2, nvs2) = matchCompDot pats vals
                in (b1 && b2, nvs1++nvs2)
            matchCompDot _ _ = (False, [])
        in matchCompDot (map unAnnotate ps) vs
    -- Both sub-patterns must match the same value; bindings of both are kept.
    bind (PDoublePattern p1 p2) v =
        let
            (m1, b1) = bind p1 v
            (m2, b2) = bind p2 v
        in (m1 && m2, b1++b2)
    bind (PLit (Int i1)) (VInt i2) | i1 == i2 = (True, [])
    bind (PLit (Bool b1)) (VBool b2) | b1 == b2 = (True, [])
    bind (PLit (Char c1)) (VChar c2) | c1 == c2 = (True, [])
    -- {} matches when `empty s` holds.
    bind (PSet []) (VSet s) | empty s = (True, [])
    -- {p} matches when `singletonValue` extracts an element from the set
    -- (presumably only for one-element sets) and p matches that element.
    bind (PSet [p]) (VSet s) =
        case singletonValue s of
            Just v  -> bind p v
            Nothing -> (False, [])
    bind (PTuple ps) (VTuple vs) =
        -- NB: len ps == len vs by typechecker
        bindAll ps (elems vs)
    -- A variable naming a data constructor or channel is a constant
    -- pattern: it binds nothing and only matches that same name.
    bind (PVar n) v | isNameDataConstructor n =
        case v of
            VChannel n' -> (n == n', [])
            VDataType n' -> (n == n', [])
            -- We have to allow these to enable patterns such as f(J) where
            -- J has arity 0.
            VDot [VChannel n'] -> (n == n', [])
            VDot [VDataType n'] -> (n == n', [])
            -- Any other value must simply fail to match rather than be
            -- rejected, because values of the same data type can reach this
            -- pattern. e.g. given the clauses
            --   f(J) = X
            --   f(X.0) = X
            -- the first clause matches J against VDot [X,0].
            _ -> (False, [])
    -- Any other variable matches everything and binds the value.
    bind (PVar n) v = (True, [(n, v)])
    bind PWildCard _ = (True, [])
    bind _ _ = (False, [])

-- | Binds each pattern to the value at the same position. The match succeeds
-- only if every pair matches; the bindings of all pairs are concatenated in
-- order.
--
-- Surplus patterns or values are silently ignored ('zipWith' truncates to the
-- shorter list), so callers must ensure the lengths agree when that matters.
bindAll :: Bindable a => [a] -> [Value] -> (Bool, [(Name, Value)])
bindAll ps xs = (all fst rs, concatMap snd rs)
    where
        rs = zipWith bind ps xs