-- |
-- Module    : Statistics.Regression
--
-- Ordinary least-squares regression ('olsRegress', 'ols') together
-- with a goodness-of-fit measure for the resulting fit ('rSquare').
module Statistics.Regression
(
olsRegress
, ols
, rSquare
) where
import Control.Applicative ((<$>))
import Prelude hiding (pred, sum)
import Statistics.Function as F
import Statistics.Matrix hiding (map)
import Statistics.Matrix.Algorithms (qr)
import Statistics.Sample (mean)
import Statistics.Sample.Internal (sum)
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as M
-- | Perform an ordinary least-squares regression on a set of
-- predictors, and calculate the goodness-of-fit of the regression.
--
-- The returned pair consists of:
--
-- * A vector of regression coefficients.  It has /one more/ element
--   than the list of predictors: a vector of ones is appended to the
--   predictors before fitting, so the last element is the constant
--   (intercept) term.
--
-- * The goodness-of-fit of the regression, as computed by 'rSquare'.
--
-- Calls 'error' if no predictors are given, if the predictor vectors
-- differ in length, or if the responder's length differs from theirs.
olsRegress :: [Vector]
              -- ^ Non-empty list of predictor vectors.  Must all have
              -- the same length.
           -> Vector
              -- ^ Responder vector.  Must have the same length as the
              -- predictor vectors.
           -> (Vector, Double)
olsRegress [] _ = error "no predictors given"
olsRegress preds@(firstPred:otherPreds) resps
  | any (/= n) otherLens =
      error $ "predictor vector length mismatch " ++ show allLens
  | respLen /= n =
      error $ "responder/predictor length mismatch " ++ show (respLen, n)
  | otherwise = (coeffs, rSquare design resps coeffs)
  where
    -- The first predictor's length is the reference every other
    -- vector is checked against.
    n         = G.length firstPred
    otherLens = map G.length otherPreds
    allLens   = n : otherLens
    respLen   = G.length resps
    -- The predictors plus a trailing vector of ones are laid out as
    -- the rows of a matrix, which is then transposed so that each
    -- predictor becomes a column.
    design    = transpose
              . fromVector (length preds + 1) n
              . G.concat
              $ preds ++ [G.replicate n 1]
    coeffs    = ols design resps
-- | Compute the ordinary least-squares solution to /A x = b/.
--
-- The matrix is factored with 'qr'; the transposed first factor is
-- applied to /b/ and the resulting system is handed to 'solve'
-- together with the second factor.
--
-- Calls 'error' if the matrix has fewer rows than columns.
ols :: Matrix     -- ^ /A/; must have at least as many rows as columns.
    -> Vector     -- ^ /b/; presumably one element per row of /A/ (not
                  -- checked here — confirm against 'multiplyV').
    -> Vector
ols a b
  | nRows >= nCols = solve upper (multiplyV (transpose orth) b)
  | otherwise      = error $ "fewer rows than columns " ++ show dims
  where
    dims@(nRows, nCols) = dimension a
    (orth, upper)       = qr a
-- | Solve the equation /R x = b/ by back substitution.
--
-- Only the entries of /R/ on or above the diagonal are read, so /R/ is
-- treated as upper-triangular.  Diagonal entries are divided by without
-- any check for zero.
--
-- Calls 'error' if the number of rows of /R/ differs from the length
-- of /b/.
solve :: Matrix     -- ^ /R/, treated as an upper-triangular matrix.
      -> Vector     -- ^ /b/, of the same length as /R/ has rows.
      -> Vector
solve r b
  | nRows /= len = error $ "row/vector mismatch " ++ show (nRows, len)
  | otherwise    = U.create $ do
      -- Work on a mutable copy of b; it is overwritten with the solution.
      acc <- U.thaw b
      -- Walk the diagonal from the last row up to the first.
      rfor nRows 0 $ \row -> do
        rhs <- M.unsafeRead acc row
        let x = rhs / unsafeIndex r row row
        M.unsafeWrite acc row x
        -- Eliminate the freshly solved unknown from every earlier row.
        for 0 row $ \k ->
          F.unsafeModify acc k (\v -> v - unsafeIndex r k row * x)
      return acc
  where
    nRows = rows r
    len   = U.length b
-- | Compute /R&#0178;/, the coefficient of determination, which
-- indicates the goodness-of-fit of a regression.
--
-- The result is @1 - r / t@, where @r@ is the sum of squared
-- differences between each responder and its predicted value, and @t@
-- is the sum of squared differences between each responder and the
-- mean of the responders.  A value of 1 therefore means the
-- predictions match the responders exactly.
--
-- No special handling is done when @t@ is zero (all responders
-- equal); the result is then whatever the floating-point division
-- yields.
rSquare :: Matrix               -- ^ Predictors (regressors), one row
                                -- per observation.
        -> Vector               -- ^ Responders.
        -> Vector               -- ^ Regression coefficients, one per
                                -- column of the predictor matrix.
        -> Double
rSquare pred resp coeff = 1 - r / t
  where
    -- Residual sum of squares.
    r   = sum $ flip U.imap resp $ \i x -> square (x - p i)
    -- Total sum of squares about the mean.
    t   = sum $ flip U.map resp $ \x -> square (x - m)
    -- Computed once rather than once per responder.
    m   = mean resp
    -- Predicted value for observation i: the dot product of row i of
    -- the predictor matrix with the coefficient vector.
    p i = sum . flip U.imap coeff $ \j -> (* unsafeIndex pred i j)