-----------------------------------------------------------------------------
-- |
-- Module      :  Database.TxtSushi.Relational
-- Copyright   :  (c) Keith Sheppard 2009-2010
-- License     :  BSD3
-- Maintainer  :  keithshep@gmail.com
-- Stability   :  experimental
-- Portability :  portable
--
-- Some functions for joining lists
--
-----------------------------------------------------------------------------

module Database.TxtSushi.Relational (
    joinTables,
    crossJoinTables,
    joinPresortedTables) where

import Data.List
import Data.Function

-- | join together two tables on the given column index pairs
joinTables :: (Ord o) =>
    (a -> o)
    -> [a]
    -> (b -> o)
    -> [b]
    -> [(a, b)]
joinTables joinOrdFunc1 table1 joinOrdFunc2 table2 =
    let
        sortedTable1 = sortBy (compare `on` joinOrdFunc1) table1
        sortedTable2 = sortBy (compare `on` joinOrdFunc2) table2
    in
        joinPresortedTables joinOrdFunc1 sortedTable1 Nothing joinOrdFunc2 sortedTable2 Nothing

-- | join together two tables that are presorted on the given column index pairs
joinPresortedTables :: (Ord o) =>
    (a -> o)
    -> [a]
    -> Maybe a
    -> (b -> o)
    -> [b]
    -> Maybe b
    -> [(a, b)]
joinPresortedTables joinOrdFunc1 sortedTable1 maybeNull1 joinOrdFunc2 sortedTable2 maybeNull2 =
    let
        tableGroups1 = groupBy rowEq1 sortedTable1
        tableGroups2 = groupBy rowEq2 sortedTable2
    in
        joinGroupedTables joinOrdFunc1 tableGroups1 maybeNull1 joinOrdFunc2 tableGroups2 maybeNull2
    where
        rowEq1 x y = (compare `on` joinOrdFunc1) x y == EQ
        rowEq2 x y = (compare `on` joinOrdFunc2) x y == EQ

crossJoinTables :: [a] -> [b] -> [(a, b)]
crossJoinTables [] _ = []
crossJoinTables _ [] = []
crossJoinTables (table1Head:table1Tail) table2 =
    map (\x -> (table1Head, x)) table2 ++ (crossJoinTables table1Tail table2)

joinGroupedTables :: (Ord o) =>
    (a -> o)
    -> [[a]]
    -> Maybe a
    -> (b -> o)
    -> [[b]]
    -> Maybe b
    -> [(a, b)]
joinGroupedTables _ [] Nothing      _ _  _       = []
joinGroupedTables _ _  _            _ [] Nothing = []

joinGroupedTables _ [] (Just null1) _ tableGroups2 _ = zip (repeat null1) (concat tableGroups2)
joinGroupedTables _ tableGroups1 _ _ [] (Just null2) = zip (concat tableGroups1) (repeat null2)

joinGroupedTables
    joinOrdFunc1
    tableGroups1@(headTableGroup1:tableGroupsTail1)
    maybeNull1
    
    joinOrdFunc2
    tableGroups2@(headTableGroup2:tableGroupsTail2)
    maybeNull2 =
    let
        headRow1 = head headTableGroup1
        headRow2 = head headTableGroup2
    in
        case joinOrdFunc1 headRow1 `compare` joinOrdFunc2 headRow2 of
            -- drop the 1st group if its smaller
            LT ->
                let joinRemainder =
                        joinGroupedTables
                            joinOrdFunc1 tableGroupsTail1 maybeNull1
                            joinOrdFunc2 tableGroups2 maybeNull2 in
                case maybeNull2 of
                    Nothing    -> joinRemainder
                    Just null2 -> zip headTableGroup1 (repeat null2) ++ joinRemainder
            
            -- drop the 2nd group if its smaller
            GT ->
                let joinRemainder =
                        joinGroupedTables
                            joinOrdFunc1 tableGroups1 maybeNull1
                            joinOrdFunc2 tableGroupsTail2 maybeNull2 in
                case maybeNull1 of
                    Nothing    -> joinRemainder
                    Just null1 -> zip (repeat null1) headTableGroup2 ++ joinRemainder
            
            -- the two groups are equal so permute
            _  ->
                let joinRemainder =
                        joinGroupedTables
                            joinOrdFunc1 tableGroupsTail1 maybeNull1
                            joinOrdFunc2 tableGroupsTail2 maybeNull2 in
                crossJoinTables headTableGroup1 headTableGroup2 ++ joinRemainder