{-# LANGUAGE TupleSections #-}
module YesodDsl.ClassImplementer (implementClasses) where
import Data.List
import YesodDsl.AST
import Data.Maybe
import Data.Generics
import Data.Generics.Uniplate.Data
import qualified Data.Map as Map
import qualified Data.List as L

-- | Find the field named @fn@ in an entity named @en@.
--
-- Entities are searched in module order and fields in declaration order; the
-- first hit wins. Returns 'Nothing' when no such entity/field pair exists.
lookupField' :: Module -> EntityName -> FieldName -> Maybe Field
lookupField' m en fn = listToMaybe $ do
    ent <- filter ((== en) . entityName) (modEntities m)
    filter ((== fn) . fieldName) (entityFields ent)


-- | Expand every class-related construct of a module.
--
-- Each entity is first rewritten by 'implInEntity'. Then, everywhere inside
-- the routes, statements are rewritten by 'trStmt' using the ORIGINAL module
-- (it looks up the unexpanded class-reference fields), and select queries are
-- rewritten by 'trSq' using the entity-expanded module.
implementClasses :: Module -> Module
implementClasses m =
    expanded { modRoutes = everywhere (mkT (trStmt m) . mkT (trSq expanded)) (modRoutes expanded) }
    where
        expanded = m { modEntities = map (implInEntity m (modClasses m)) (modEntities m) }

-- | Expand class-reference assignments in @update@/@insert@ statements that
-- target a named entity (@Left en@).
--
-- An assignment @(pn, RequestField pn, x)@ whose target field @pn@ refers to a
-- class becomes one assignment per entity implementing that class. Both the
-- target and the request field are renamed to
-- @lowerFirst entityName ++ upperFirst pn@, and @x@ is kept. All other
-- assignments and statements are returned unchanged.
--
-- Must be given the module *before* entity expansion, because it looks up the
-- original class-reference field by its unexpanded name.
trStmt :: Module -> Stmt -> Stmt
trStmt m stmt = case stmt of
    Update ref@(Left en) fr (Just assigns) ->
        Update ref fr (Just $ concatMap (expandAssign en) assigns)
    Insert ref@(Left en) (Just (mvn, assigns)) mvn2 ->
        Insert ref (Just (mvn, concatMap (expandAssign en) assigns)) mvn2
    other -> other
    where
        -- Only assignments of a request field to a field of the same name
        -- are candidates for expansion.
        expandAssign en assign@(pn, RequestField rfn, mfn)
            | pn == rfn = fromMaybe [assign] (expandedAssigns en pn mfn)
        expandAssign _ assign = [assign]

        -- Per-implementing-entity assignments, or Nothing when field @pn@ of
        -- entity @en@ is missing or is not a reference to a class.
        expandedAssigns en pn mfn = do
            field <- lookupField' m en pn
            refName <- case fieldContent field of
                EntityField name -> Just name
                _ -> Nothing
            cls <- classLookup (modClasses m) refName
            let renamed ent = lowerFirst (entityName ent) ++ upperFirst pn
            Just [ (renamed ent, RequestField (renamed ent), mfn)
                 | ent <- modEntities m,
                   className cls `elem` entityInstances ent ]

            

-- | Rewrite a select query so that joins against a class become joins against
-- every entity implementing that class.
--
-- * Each join whose entity reference names a class is replaced by one join per
--   implementing entity, aliased @<alias>_<EntityName>@. In its ON expression
--   the class alias is renamed, and class-reference fields of other aliases are
--   mapped to their per-entity expansion. Operands of AND/OR that no longer
--   refer to existing fields are dropped.
-- * @SELECT alias.*@ and @alias.id@ are repeated for every alias the original
--   alias now stands for. Named fields are repeated only for aliases whose
--   entity actually has that field. Result aliases are suffixed with the
--   entity name.
-- * Sub-expressions of the WHERE clause that mention an expanded alias become
--   a disjunction (OR) over those per-entity variants that are still valid.
--
-- Expects the module *after* entity expansion, because it looks up the
-- expanded per-entity fields.
trSq :: Module -> SelectQuery -> SelectQuery
trSq m sq = sq {
        sqFields = concatMap trSelectField $ sqFields sq,
        sqJoins = map snd newJoins,
        sqWhere = everywhere (mkT trExpr) $ sqWhere sq
    }
    where
        -- Joins after expansion. The first component is
        -- Just (oldAlias, (newAlias, Just entityName)) for a join produced by
        -- expanding a class join, and Nothing for a join kept as-is.
        newJoins = concatMap expandJoin $ sqJoins sq

        vnMap :: [(VariableName, (VariableName, Maybe EntityName))]
        vnMap = mapMaybe fst newJoins

        -- (entity or class name, alias) for FROM and every original join.
        aliases = map (\(er,vn) -> (entityRefName er, vn)) $
            sqFrom sq : [ (joinEntity j, joinAlias j) | j <- sqJoins sq ]

        -- Original aliases plus the per-entity aliases created by expansion.
        allAliases = aliases ++ [ (en, vn) | (_, (vn, Just en)) <- vnMap ]

        -- Expanded aliases grouped by the alias they replace. Built once here
        -- rather than on every 'newAliases' call.
        -- NOTE(review): fromListWith (++) lists the expansions of an alias in
        -- reverse join order; confirm the resulting field order is intended.
        expansions = Map.fromListWith (++) [ (s, [d]) | (s, d) <- vnMap ]

        -- The aliases an original alias now stands for. An alias that was not
        -- expanded stands for itself.
        newAliases vn = Map.findWithDefault [(vn, Nothing)] vn expansions

        -- Replace a join against a class by one join per implementing entity.
        -- Joins against ordinary entities are kept unchanged.
        expandJoin j = fromMaybe [(Nothing, j)] $ do
            c <- classLookup (modClasses m) $ entityRefName $ joinEntity j
            let a = joinAlias j
            Just [ (Just (a, (a', Just en)), j {
                        joinAlias = a',
                        joinEntity = Left en,
                        joinExpr = fmap (everywhere (mkT trDropInvalidExprs
                                                     . mkT (trClassField en)
                                                     . mkT (trVar a a')))
                                        (joinExpr j)
                    })
                 | e <- modEntities m,
                   className c `elem` entityInstances e,
                   let en = entityName e,
                   let a' = a ++ "_" ++ en ]

        -- Map a class-reference field of an original alias to the expanded
        -- field pointing at entity @en@, when such a field exists.
        trClassField en fr = case fr of
            SqlField v@(Var vn _ _) fn -> fromMaybe fr $ do
                (en', _) <- L.find ((== vn) . snd) aliases
                f <- lookupField' m en' (lowerFirst en ++ upperFirst fn)
                Just $ SqlField v $ fieldName f
            _ -> fr

        -- Rename references to alias @srcVn@ into @dstVn@.
        trVar srcVn dstVn fr = case fr of
            SqlField (Var vn _ _) fn | vn == srcVn -> SqlField (Var dstVn (Left "") False) fn
            SqlId (Var vn _ _) | vn == srcVn -> SqlId (Var dstVn (Left "") False)
            _ -> fr

        -- Result alias for a column selected through an expanded alias: the
        -- explicit alias (or the field name) suffixed with the entity name.
        aliasName :: FieldName -> Maybe VariableName -> Maybe EntityName -> Maybe VariableName
        aliasName fn man (Just en) = Just $ maybe (fn ++ en) (++ en) man
        aliasName _ man Nothing = man

        trSelectField sf = case sf of
            SelectAllFields (Var vn _ _) ->
                [ SelectAllFields (Var vn' (Left "") False) | (vn', _) <- newAliases vn ]
            SelectField (Var vn _ _) fn man ->
                [ SelectField (Var vn' (Left "") False) fn $ aliasName fn man men
                | (vn', men) <- newAliases vn, validField (vn', fn) ]
            SelectIdField (Var vn _ _) man ->
                [ SelectIdField (Var vn' (Left "") False) $ aliasName "id" man men
                | (vn', men) <- newAliases vn ]
            SelectExpr _ _ -> [sf]

        -- Turn an expression that mentions expanded aliases into an OR of
        -- its valid per-entity variants. An expression with no such variant
        -- is left unchanged.
        trExpr e =
            let variants = [ e' | (s, (d, _)) <- vnMap,
                                  let e' = everywhere (mkT (trVar s d)) e,
                                  e' /= e,
                                  validExpr e' ]
            in if null variants then e else foldl1 mkOrExpr variants

        mkOrExpr e1 e2 = BinOpExpr e1 Or e2

        -- A tautology (TRUE = TRUE), used when both operands of a boolean
        -- connective had to be dropped.
        trueExpr = let c = FieldExpr (Const (BoolValue True)) in BinOpExpr c Eq c

        -- Drop operands of AND/OR that refer to non-existent fields.
        trDropInvalidExprs e = case e of
            BinOpExpr e1 And e2 -> keepValid e e1 e2
            BinOpExpr e1 Or e2 -> keepValid e e1 e2
            _ -> e

        -- Keep @whole@ if both operands are valid. Otherwise keep the valid
        -- operand, or return a tautology if neither is valid.
        keepValid whole e1 e2 = case (validExpr e1, validExpr e2) of
            (True, True) -> whole
            (False, True) -> e2
            (True, False) -> e1
            (False, False) -> trueExpr

        -- An expression is valid when every field it references exists.
        validExpr e = all validField [ (vn, fn) | SqlField (Var vn _ _) fn <- universeBi e ]

        -- Whether alias @vn@ is bound to an entity that has field @fn@.
        validField (vn, fn) = isJust $ do
            (en, _) <- L.find ((== vn) . snd) allAliases
            lookupField' m en fn

-- | Look up a class declaration by name. The first matching declaration wins.
classLookup :: [Class] -> ClassName -> Maybe Class
classLookup classes name = listToMaybe [ c | c <- classes, name == className c ]


-- | Expand a field that references a class into one optional field per entity
-- implementing that class.
--
-- Each generated field is named @lowerFirst entityName ++ upperFirst field@,
-- references that entity, keeps the original location and options, and
-- records @(cn, originalFieldName)@ in 'fieldClassName'.
-- A non-optional class reference is rejected with 'error'. A field whose
-- content is not an entity reference yields no fields.
expandClassField :: Module -> ClassName -> Entity ->  Field -> [Field]
expandClassField m cn e f = case f of
    Field _ _ _ (EntityField iName) opts _
        | fieldOptional f ->
            [ Field {
                  fieldLoc = fieldLoc f,
                  fieldOptional = True,
                  fieldName = lowerFirst (entityName re) ++ upperFirst (fieldName f),
                  fieldContent = EntityField (entityName re),
                  fieldOptions = opts,
                  fieldClassName = Just (cn, fieldName f)
              }
            | re <- modEntities m, iName `elem` entityInstances re ]
        | otherwise ->
            error $ show (entityLoc e) ++ ": non-maybe reference to class not allowed"
    _ -> []

-- | Expand one field of entity @e@.
--
-- * A reference to the pseudo-entity @"ClassInstance"@ is retargeted to @e@.
-- * A reference to a declared class is expanded by 'expandClassField'.
-- * Any other field is returned unchanged.
expandClassRefFields :: Module -> Entity -> Field -> [Field]
expandClassRefFields m e f = case fieldContent f of
    EntityField "ClassInstance" -> [f { fieldContent = EntityField (entityName e) }]
    EntityField name
        | isJust (classLookup (modClasses m) name) -> expandClassField m name e f
    _ -> [f]
            
-- | Expand a unique constraint that mentions class-reference fields.
--
-- Each class-reference field of @e@ listed in the constraint is replaced, in
-- turn, by every field 'expandClassRefFields' produces for it. This yields
-- one constraint per combination. The substituted field is moved to the end
-- of the field list, and the referenced entity's name is appended to the
-- constraint name. A constraint without class-reference fields is returned
-- as a singleton.
expandClassRefUniques :: Module -> Entity -> Unique -> [Unique]
expandClassRefUniques m e u = expand [u] cFields
    where
        -- Substitute each class-reference field in turn, multiplying the
        -- candidate constraints by that field's expansions.
        expand us (f:fs) = expand (concatMap (expandField f) us) fs
        expand us [] = us

        -- Replace field @f@ in constraint @u'@ with each of its expansions.
        expandField f u'
            | fieldName f `elem` uniqueFields u' = [
                    u' {
                        uniqueName = uniqueName u' ++ fieldEntityName f',
                        uniqueFields = (uniqueFields u' L.\\ [ fieldName f ]) ++ [ fieldName f' ]
                    }
                    | f' <- expandClassRefFields m e f
                ]
            | otherwise = [u']

        -- Entity referenced by a field, or "" if it is not an entity reference.
        fieldEntityName f = case fieldContent f of
            EntityField en -> en
            _ -> ""

        -- Class-reference fields of @e@ named by the constraint, in the order
        -- the constraint lists them.
        cFields = [ f | fn <- uniqueFields u,
                         f <- entityFields e,
                         fieldName f == fn,
                         isClassField m f ]
         
-- | Whether the field is a reference to one of the module's declared classes.
isClassField :: Module -> Field -> Bool
isClassField m f = case f of
    Field _ _ _ (EntityField iName) _ _ -> any ((== iName) . className) (modClasses m)
    _ -> False
-- | Merge the classes an entity implements into the entity.
--
-- * The fields of every implemented class are appended to the entity's own
--   fields, and the result is passed through 'expandClassRefFields'.
-- * The entity's own class-reference fields (before expansion) are recorded
--   in 'entityClassFields'.
-- * Class uniques are added with the entity name prefixed to their names, and
--   every unique is passed through 'expandClassRefUniques'.
--
-- Implemented classes are resolved by name from @classes@, in the order they
-- are listed in 'entityInstances'. Names with no matching class are ignored.
implInEntity :: Module -> [Class] -> Entity -> Entity
implInEntity m classes e = e {
        entityFields = concatMap (expandClassRefFields m e) $
                            entityFields e ++ extraFields,
        entityClassFields = filter (isClassField m) $ entityFields e,
        entityUniques = concatMap (expandClassRefUniques m e) $
                            entityUniques e ++ map addEntityNameToUnique (concatMap classUniques validClasses)
    }
    where
        -- Lookup by name already yields classes in instance order, so the
        -- class list does not need to be sorted first.
        validClasses = mapMaybe (classLookup classes) $ entityInstances e
        extraFields = concatMap classFields validClasses
        addEntityNameToUnique (Unique name fields) = Unique (entityName e ++ name) fields