module HLearn.Models.Classifiers.Experimental.Boosting.MonoidBoost
    where
import Control.Applicative
import Data.List
import qualified Data.Foldable as F
import qualified Data.Sequence as Seq
import Data.Sequence (fromList)
import GHC.TypeLits
import Debug.Trace
import Test.QuickCheck
import HLearn.Algebra
import HLearn.Models.Distributions.Visualization.Gnuplot
import HLearn.Models.Distributions
import HLearn.Models.Classifiers.Common
-- | An ensemble classifier that trains base models over sliding windows of
-- the datapoint stream.  The type-level @k@ fixes the window geometry: each
-- base model is trained on @2k+1@ consecutive datapoints (see the 'Monoid'
-- instance in this module).
data MonoidBoost (k::Nat) basemodel = MonoidBoost
    { dataL :: Seq.Seq (Datapoint basemodel)  -- ^ all datapoints seen so far, in arrival order
    , modelL :: Seq.Seq basemodel             -- ^ one trained base model per window
    , weightL :: Seq.Seq (Ring basemodel)     -- ^ per-model weights; NOTE(review): always 'mempty' in the code visible here — confirm whether weights are used elsewhere
    , boost_numdp :: Int                      -- ^ number of datapoints stored in 'dataL'
    }
    
-- Standalone deriving: 'MonoidBoost' supports each class whenever the
-- base model, its datapoint type, and its ring type all do.
deriving instance (Read (Datapoint basemodel), Read (Ring basemodel), Read basemodel) => Read (MonoidBoost k basemodel)
deriving instance (Show (Datapoint basemodel), Show (Ring basemodel), Show basemodel) => Show (MonoidBoost k basemodel)
deriving instance (Eq   (Datapoint basemodel), Eq   (Ring basemodel), Eq   basemodel) => Eq   (MonoidBoost k basemodel)
deriving instance (Ord  (Datapoint basemodel), Ord  (Ring basemodel), Ord  basemodel) => Ord  (MonoidBoost k basemodel)
-- | Generate a random 'MonoidBoost' by training on a random list of
-- datapoints, so every generated value is reachable through 'train'.
instance 
    ( HomTrainer basemodel
    , Arbitrary (Datapoint basemodel)
    , SingI k
    ) => Arbitrary (MonoidBoost k basemodel) 
        where
    arbitrary = fmap train (listOf arbitrary)
-- | QuickCheck property checking that '<>' on 'MonoidBoost' is associative,
-- instantiated at 'Normal' distributions over 'Rational' with @k = 3@.
testassociativity = quickCheck prop
    where
        prop :: MonoidBoost 3 (Normal Rational)
             -> MonoidBoost 3 (Normal Rational)
             -> MonoidBoost 3 (Normal Rational)
             -> Bool
        prop m1 m2 m3 = (m1<>m2)<>m3 == m1<>(m2<>m3)
-- | Keep only the last @k@ elements of a sequence.  If the sequence has
-- @k@ or fewer elements it is returned unchanged, because 'Seq.drop' of a
-- non-positive count is a no-op.
leave :: Int -> Seq.Seq a -> Seq.Seq a
leave k xs = Seq.drop (Seq.length xs - k) xs
-- | Merging two boosts concatenates their datasets and retrains base models
-- over the "seam" where the two datasets meet, so the merged ensemble has a
-- model for every window of @2k+1@ consecutive datapoints of the combined
-- data — the key property that makes training a monoid homomorphism.
instance 
    ( HomTrainer basemodel
    , SingI k
    ) => Monoid (MonoidBoost k basemodel) 
        where
    mempty = MonoidBoost mempty mempty mempty 0
    mb1 `mappend` mb2 = MonoidBoost
        { dataL     = dataL'
        , modelL    = modelL mb1 <> newmodel <> modelL mb2
        , weightL   = mempty -- NOTE(review): any existing weights are discarded on merge — confirm intended
        , boost_numdp     = boost_numdp'
        }
        where
            boost_numdp' = boost_numdp mb1 + boost_numdp mb2
            dataL' = dataL mb1 <> dataL mb2
            
            -- Retrain only over the seam: the last 2k datapoints of mb1
            -- followed by the first 2k datapoints of mb2.
            newmodel = Seq.fromList $ newmodels $ leave (2*k) (dataL mb1) <> Seq.take (2*k) (dataL mb2)
            -- Slide a window of size 2k+1 one datapoint at a time, training
            -- a base model on each full window; stop when fewer than a full
            -- window remains.
            newmodels xs = if Seq.length xs >= modelsize
                then (train (Seq.take modelsize xs)):(newmodels $ Seq.drop 1 xs)
                else []
            modelsize = 2*k+1
            -- Reflect the type-level Nat k down to a term-level number.
            k = fromIntegral $ fromSing (sing::Sing k)
-- | Training a single datapoint stores it with no models and no weights;
-- models appear only when singletons are merged via the 'Monoid' instance.
instance 
    ( SingI k
    , HomTrainer basemodel
    ) => HomTrainer (MonoidBoost k basemodel) 
        where
    type Datapoint (MonoidBoost k basemodel) = Datapoint basemodel
    train1dp dp = MonoidBoost
        { dataL       = Seq.singleton dp
        , modelL      = Seq.empty
        , weightL     = Seq.empty
        , boost_numdp = 1
        }
    
-- | The ensemble's probability type is that of its base model.
instance Probabilistic (MonoidBoost k basemodel) where
    type Probability (MonoidBoost k basemodel) = Probability basemodel
-- | Classify by monoidally combining ('reduce') the result distributions
-- produced by every base model in the ensemble.
instance
    ( ProbabilityClassifier basemodel
    , Monoid (ResultDistribution basemodel)
    ) => ProbabilityClassifier (MonoidBoost k basemodel)
        where
    type ResultDistribution (MonoidBoost k basemodel) = ResultDistribution basemodel    
    probabilityClassify mb dp =
        reduce (fmap (\basem -> probabilityClassify basem dp) (modelL mb))
    
-- | The ensemble density at a point is the arithmetic mean of each base
-- model's density there.  (Uses 'F.foldl1', so an ensemble with no models
-- is an error, matching the original behavior.)
instance 
    ( PDF basemodel
    , Fractional (Probability basemodel)
    ) => PDF (MonoidBoost k basemodel)
        where
    pdf mb dp = total / count
        where
            densities = fmap (\basem -> pdf basem dp) (modelL mb)
            total     = F.foldl1 (+) densities
            count     = fromIntegral (Seq.length densities)
-- | Plot the ensemble by pooling the sample points of every base model;
-- the plot type is inherited from the base model.
instance 
    ( PlottableDistribution basemodel
    , Fractional (Probability basemodel)
    ) => PlottableDistribution (MonoidBoost k basemodel)
        where
    
    -- 'concatMap' replaces the non-idiomatic 'concat $ map'.
    samplePoints mb = concatMap samplePoints $ F.toList $ modelL mb
    -- The argument is only a type proxy and is never evaluated.
    plotType _ = plotType (undefined :: basemodel)