Skip to content

Commit

Permalink
- test-statistics.arr: added tests for the new stats functions brownp…
Browse files Browse the repository at this point in the history
…lt#1732

- statistics.arr: added exceptions for t-test-{pooled, independent}
  • Loading branch information
ds26gte committed Apr 13, 2024
1 parent a52ef2d commit 2ba6dd1
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 14 deletions.
30 changes: 19 additions & 11 deletions src/arr/trove/statistics.arr
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ fun t-test-paired(l1 :: List, l2 :: List) -> Number:
if n1 <> n2:
raise(E.message-exception("t-test-paired: input lists must have equal lengths"))
else if n1 == 0:
raise(E.message-exception("t-test-paired: input lists should have at least one element"))
raise(E.message-exception("t-test-paired: input lists must have at least one element"))
else:
diffs = map2(lam(x1, x2): x1 - x2 end, l1, l2)
diffs-mean = mean(diffs)
Expand All @@ -242,22 +242,30 @@ fun t-test-pooled(l1 :: List, l2 :: List) -> Number:
doc: "t-test-pooled"
n1 = l1.length()
n2 = l2.length()
m1 = mean(l1)
m2 = mean(l2)
v1 = variance-sample(l1)
v2 = variance-sample(l2)
(m1 - m2) / (((((n1 - 1) * num-expt(v1, 2)) + ((n2 - 1) * num-expt(v2, 2))) / ((n1 + n2) - 2)) * num-sqrt((1 / n1) + (1 / n2)))
if (n1 == 0) or (n2 == 0):
raise(E.message-exception("t-test-pooled: input lists must have at least one element"))
else:
m1 = mean(l1)
m2 = mean(l2)
v1 = variance-sample(l1)
v2 = variance-sample(l2)
(m1 - m2) / (((((n1 - 1) * num-expt(v1, 2)) + ((n2 - 1) * num-expt(v2, 2))) / ((n1 + n2) - 2)) * num-sqrt((1 / n1) + (1 / n2)))
end
end

fun t-test-independent(l1 :: List, l2 :: List) -> Number:
doc: "t-test-independent"
n1 = l1.length()
n2 = l2.length()
m1 = mean(l1)
m2 = mean(l2)
v1 = variance-sample(l1)
v2 = variance-sample(l2)
(m1 - m2) / num-sqrt((v1 / n1) + (v2 / n2))
if (n1 == 0) or (n2 == 0):
raise(E.message-exception("t-test-independent: input lists must have at least one element"))
else:
m1 = mean(l1)
m2 = mean(l2)
v1 = variance-sample(l1)
v2 = variance-sample(l2)
(m1 - m2) / num-sqrt((v1 / n1) + (v2 / n2))
end
end

fun chi-square(obs :: List, exp :: List) -> Number:
Expand Down
28 changes: 25 additions & 3 deletions tests/pyret/tests/test-statistics.arr
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,40 @@ check "numeric helpers":
modes([list: -1, 2, -1, 2, -1]) is [list: -1]
modes([list: 1, 1, 2, 2, 3, 3, 3]) is [list: 3]

variance([list:]) raises "empty"
variance([list: 5]) is 0
variance([list: 3, 4, 5, 6, 7]) is%(within(0.01)) 2.0
variance([list: 1, 1, 1, 1]) is-roughly ~0

stdev([list:]) raises "empty"
stdev([list: 5]) is 0

stdev([list: 3, 4, 5, 6, 7]) is%(within(0.01)) 1.41
stdev([list: 1, 1, 1, 1]) is-roughly ~0
stdev([list:]) raises "empty"


variance-sample([list: 3, 4, 5, 6, 7]) is%(within(0.01)) (10 / 4)
variance-sample([list: 3]) raises "division by zero"
variance-sample([list: 1, 1, 1, 1]) is-roughly ~0
variance-sample([list:]) raises "empty"

stdev-sample([list: 3, 4, 5, 6, 7]) is%(within(0.01)) num-sqrt(10 / 4)
stdev-sample([list: 3]) raises "division by zero"
stdev-sample([list: 1, 1, 1, 1]) is-roughly ~0
stdev-sample([list:]) raises "empty"

t-test-paired([list: 1], [list: 2, 3]) raises "lists must have equal lengths"
t-test-paired([list:], [list:]) raises "lists must have at least one element"
t-test-paired([list: 1, 2, 3], [list: 4, 6, 8]) is%(within(0.01)) -6.928

t-test-pooled([list:], [list: 1, 2, 3]) raises "lists must have at least one element"
t-test-pooled([list: 1, 2, 3], [list: 4, 5, 6]) is%(within(0.01)) -3.674
t-test-pooled([list: 1, 2, 3], [list: 4, 5, 6, 7]) is%(within(0.01)) -2.217

t-test-independent([list:], [list: 1, 2, 3]) raises "lists must have at least one element"
t-test-independent([list: 1, 2, 3], [list: 4, 5, 6]) is%(within(0.01)) -3.674
t-test-independent([list: 1, 2, 3], [list: 4, 5, 6, 7]) is%(within(0.01)) -4.041

chi-square([list: 1, 2, 3, 4], [list: 1, 2, 3, 4]) is 0
chi-square([list: 1, 2, 3, 4], [list: 0.9, 1.8, 3.5, 4.7]) is%(within(0.01)) 0.209
end

check "linear regression":
Expand Down

0 comments on commit 2ba6dd1

Please sign in to comment.