-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathcausal_test_outcome.py
115 lines (89 loc) · 4.77 KB
/
causal_test_outcome.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# pylint: disable=too-few-public-methods
"""This module contains the CausalTestOutcome abstract class, as well as the concrete extension classes:
ExactValue, Positive, Negative, SomeEffect, NoEffect"""
from abc import ABC, abstractmethod
from collections.abc import Iterable
import numpy as np
from causal_testing.testing.causal_test_result import CausalTestResult
class CausalTestOutcome(ABC):
    """Abstract base class for an expected causal effect that a test result can be checked against."""

    @abstractmethod
    def apply(self, res: CausalTestResult) -> bool:
        """Decide whether the given result satisfies this expected outcome.

        Concrete subclasses must override this method.

        :param res: CausalTestResult to be checked
        :return: True if the outcome is met, False otherwise
        """

    def __str__(self) -> str:
        """Return the name of the concrete outcome class."""
        return self.__class__.__name__
class SomeEffect(CausalTestOutcome):
    """An extension of TestOutcome representing that the expected causal effect should not be zero."""

    def apply(self, res: CausalTestResult) -> bool:
        """Check that at least one interval of the result lies strictly on one side of the "no effect" value.

        The "no effect" value is 1 for a ``risk_ratio`` and 0 for a ``coefficient`` or ``ate``. An interval only
        counts when its lower bound is also strictly below its upper bound.

        :param res: CausalTestResult to be checked
        :return: True if any (ci_low, ci_high) pair excludes the "no effect" value
        :raises ValueError: if the test value type is not ``risk_ratio``, ``coefficient`` or ``ate``
        """
        effect_type = res.test_value.type
        if effect_type == "risk_ratio":
            null_value = 1
        elif effect_type in ("coefficient", "ate"):
            null_value = 0
        else:
            raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")

        # Bounds are paired positionally; one interval clear of the null value is sufficient.
        for lower, upper in zip(res.ci_low(), res.ci_high()):
            if null_value < lower < upper or lower < upper < null_value:
                return True
        return False
class NoEffect(CausalTestOutcome):
    """An extension of TestOutcome representing that the expected causal effect should be zero."""

    def __init__(self, atol: float = 1e-10, ctol: float = 0.05):
        """
        :param atol: Arithmetic tolerance. The test will pass if the absolute value of the causal effect is less than
                     atol.
        :param ctol: Categorical tolerance. The test will pass if this proportion of categories pass.
        """
        self.atol = atol
        self.ctol = ctol

    @staticmethod
    def _bounds_and_values(res: CausalTestResult):
        """Fetch the lower bounds, upper bounds and test values of ``res`` as three zippable iterables.

        A non-iterable upper bound is taken to mean that the result describes a single effect. In that case all
        three items are wrapped in one-element lists; previously only the value was wrapped, so zipping the scalar
        bounds raised a TypeError.

        :param res: CausalTestResult to read the bounds and values from
        :return: Tuple of (lower bounds, upper bounds, values)
        """
        ci_low, ci_high, value = res.ci_low(), res.ci_high(), res.test_value.value
        if not isinstance(ci_high, Iterable):
            return [ci_low], [ci_high], [value]
        return ci_low, ci_high, value

    def apply(self, res: CausalTestResult) -> bool:
        """Check that the result is compatible with there being no causal effect.

        For a ``risk_ratio`` the check passes if any interval strictly contains 1 or any value is within ``atol``
        of 1 according to ``np.isclose``. For a ``coefficient`` or ``ate`` an effect fails when its interval does not
        strictly contain 0 and its absolute value is not below ``atol``; the check passes when the proportion of
        failing effects is strictly below ``ctol``.

        :param res: CausalTestResult to be checked
        :return: True if the outcome is met
        :raises ValueError: if the test value type is not ``risk_ratio``, ``coefficient`` or ``ate``
        :raises ZeroDivisionError: if a ``coefficient`` or ``ate`` result has an empty test value
        """
        if res.test_value.type == "risk_ratio":
            lows, highs, values = self._bounds_and_values(res)
            return any(
                ci_low < 1 < ci_high or np.isclose(value, 1.0, atol=self.atol)
                for ci_low, ci_high, value in zip(lows, highs, values)
            )
        if res.test_value.type in ("coefficient", "ate"):
            lows, highs, values = self._bounds_and_values(res)
            failures = sum(
                not ((ci_low < 0 < ci_high) or abs(v) < self.atol) for ci_low, ci_high, v in zip(lows, highs, values)
            )
            # The denominator is the number of test values, not the number of zipped triples.
            return failures / len(values) < self.ctol
        raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
class ExactValue(SomeEffect):
    """An extension of TestOutcome representing that the expected causal effect should be a specific value."""

    def __init__(self, value: float, atol: float = None):
        """
        :param value: The expected value of the causal effect.
        :param atol: Absolute tolerance around ``value``. Defaults to 5% of the magnitude of ``value``.
        :raises ValueError: if ``atol`` is negative
        """
        self.value = value
        self.atol = abs(value * 0.05) if atol is None else atol
        if self.atol < 0:
            raise ValueError("Tolerance must be an absolute value.")

    def apply(self, res: CausalTestResult) -> bool:
        """Check that the test value is close to the expected value.

        When ``res.ci_valid()`` is true, the ``SomeEffect`` check must pass as well. Closeness is decided by
        ``np.isclose`` with ``atol=self.atol``; note that ``np.isclose`` additionally applies its default relative
        tolerance. The comparison is reduced with ``np.all`` and converted to a plain ``bool`` so that the declared
        return type holds for numpy scalars and arrays alike, in line with ``Positive`` and ``Negative``.

        :param res: CausalTestResult to be checked
        :return: True if the outcome is met
        :raises ValueError: if ``res.ci_valid()`` is true and the test value type is rejected by ``SomeEffect``
        """
        if res.ci_valid() and not super().apply(res):
            return False
        return bool(np.all(np.isclose(res.test_value.value, self.value, atol=self.atol)))

    def __str__(self):
        """Return the expected value together with its tolerance."""
        return f"ExactValue: {self.value}±{self.atol}"
class Positive(SomeEffect):
    """An extension of TestOutcome representing that the expected causal effect should be positive."""

    def apply(self, res: CausalTestResult) -> bool:
        """Check that the first test value indicates a positive effect.

        When ``res.ci_valid()`` is true, the ``SomeEffect`` check must pass first. The first test value must then be
        greater than 0 for an ``ate`` or ``coefficient``, or greater than 1 for a ``risk_ratio``.

        :param res: CausalTestResult to be checked
        :return: True if the outcome is met
        :raises ValueError: if the test value type is not ``ate``, ``coefficient`` or ``risk_ratio``
        """
        if res.ci_valid():
            if not super().apply(res):
                return False

        effect_type = res.test_value.type
        if effect_type == "risk_ratio":
            # A ratio expresses an increase when it exceeds 1.
            return bool(res.test_value.value[0] > 1)
        if effect_type in {"ate", "coefficient"}:
            return bool(res.test_value.value[0] > 0)
        raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
class Negative(SomeEffect):
    """An extension of TestOutcome representing that the expected causal effect should be negative."""

    def apply(self, res: CausalTestResult) -> bool:
        """Check that the first test value indicates a negative effect.

        When ``res.ci_valid()`` is true, the ``SomeEffect`` check must pass first. The first test value must then be
        less than 0 for an ``ate`` or ``coefficient``, or less than 1 for a ``risk_ratio``.

        :param res: CausalTestResult to be checked
        :return: True if the outcome is met
        :raises ValueError: if the test value type is not ``ate``, ``coefficient`` or ``risk_ratio``
        """
        if res.ci_valid() and not super().apply(res):
            return False

        effect_type = res.test_value.type
        if effect_type in {"ate", "coefficient"}:
            null_value = 0
        elif effect_type == "risk_ratio":
            null_value = 1
        else:
            # Reachable for an unsupported type when the confidence-interval check above was skipped; the explicit
            # raise also keeps pylint satisfied that every path returns or raises.
            raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
        return bool(res.test_value.value[0] < null_value)