jaror, 1 month ago Admittedly it gets more complicated when summing two things at the same time: let Pair dnormMean dnormNormMean = fold (Pair <$> dimap (\(Pair _ dnormI) -> dnormI) (/ fromIntegral cc) sum <*> dimap (\(Pair normBti dnormI) -> normBti * dnormI) (/ fromIntegral cc) sum) $ map (\i -> Pair (((inp ! (off + i)) - meanBt) * rstdBt) ((weight ! i) * (dout ! (off + i)))) [0 .. cc - 1] float dnorm_mean = 0.0f; float dnorm_norm_mean = 0.0f; for (int i = 0; i < C; i++) { float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt; float dnorm_i = weight[i] * dout_bt[i]; dnorm_mean += dnorm_i; dnorm_norm_mean += dnorm_i * norm_bti; } dnorm_mean = dnorm_mean / C; dnorm_norm_mean = dnorm_norm_mean / C;
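Admittedly, it gets more complicated when accumulating two sums at the same time. The Haskell version below combines two folds into one with `Pair <$> … <*> …` (each fold is adapted with `dimap` to pick its input from the pair and divide its total by `cc`); the equivalent C loop follows it: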
let Pair dnormMean dnormNormMean = fold (Pair <$> dimap (\(Pair _ dnormI) -> dnormI) (/ fromIntegral cc) sum <*> dimap (\(Pair normBti dnormI) -> normBti * dnormI) (/ fromIntegral cc) sum) $ map (\i -> Pair (((inp ! (off + i)) - meanBt) * rstdBt) ((weight ! i) * (dout ! (off + i)))) [0 .. cc - 1]
float dnorm_mean = 0.0f; float dnorm_norm_mean = 0.0f; for (int i = 0; i < C; i++) { float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt; float dnorm_i = weight[i] * dout_bt[i]; dnorm_mean += dnorm_i; dnorm_norm_mean += dnorm_i * norm_bti; } dnorm_mean = dnorm_mean / C; dnorm_norm_mean = dnorm_norm_mean / C;