-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathsurrogate_search_algorithms.py
116 lines (93 loc) · 4.91 KB
/
surrogate_search_algorithms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
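"""Module containing a genetic-algorithm implementation of SearchAlgorithm for surrogate-assisted causal search."""
# Each fitness function must be created separately for every surrogate model so that it captures that model's
# variables: pygad only ever calls it with the GA instance, the candidate solution and the solution's index.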
import inspect
from fractions import Fraction
from operator import itemgetter

from pygad import GA

from causal_testing.specification.causal_specification import CausalSpecification
from causal_testing.testing.estimators import CubicSplineRegressionEstimator
from causal_testing.surrogate.causal_surrogate_assisted import SearchAlgorithm
class GeneticSearchAlgorithm(SearchAlgorithm):
    """Implementation of SearchAlgorithm class. Implements genetic search algorithm for surrogate models.

    For every surrogate model a pygad genetic algorithm searches over the treatment value and the values of the
    surrogate's adjustment set, maximising a fitness that measures how strongly the surrogate's estimated ATE
    contradicts the surrogate's ``expected_relationship``. The highest-fitness solution over all surrogates wins.
    """

    def __init__(self, delta=0.05, config: dict = None) -> None:
        """
        :param delta: Offset applied either side of a candidate treatment value: the ATE is estimated between a
            control value of ``solution[0] - delta`` and a treatment value of ``solution[0] + delta``.
        :param config: Optional pygad ``GA`` options overriding the defaults used in ``search``. ``gene_space``
            is not allowed because it is generated from the causal specification.
        """
        super().__init__()
        self.delta = delta
        self.config = config
        # Fitness of an ATE value for each expected relationship: the larger the returned value, the more strongly
        # the ATE contradicts that relationship.
        self.contradiction_functions = {
            "positive": lambda x: -1 * x,
            "negative": lambda x: x,
            "no_effect": abs,
            # abs(1 / x) grows without bound as the ATE approaches 0, so a zero ATE gets infinite fitness
            # instead of raising ZeroDivisionError.
            "some_effect": lambda x: float("inf") if x == 0 else abs(1 / x),
        }

    # pylint: disable=too-many-locals
    def search(
        self, surrogate_models: list[CubicSplineRegressionEstimator], specification: CausalSpecification
    ) -> list:
        """Run one genetic search per surrogate model and return the best solution found overall.

        :param surrogate_models: Surrogate models to search; each must provide ``expected_relationship``,
            ``treatment`` and ``adjustment_set``.
        :param specification: The causal specification whose scenario constraints and variable datatypes define
            the search space.
        :return: A ``(solution_dict, fitness, surrogate)`` tuple for the highest-fitness solution, where
            ``solution_dict`` maps the treatment and each adjustment variable to its value.
            NOTE(review): the annotation says ``list`` (presumably to match ``SearchAlgorithm.search``) but a
            tuple is returned.
        :raises ValueError: If ``config`` sets ``gene_space``, if ``surrogate_models`` is empty, or if a surrogate
            estimates more than one ATE value.
        """
        # Validate before building any GA so a bad config fails fast.
        if self.config is not None and "gene_space" in self.config:
            raise ValueError(
                "Gene space should not be set through config. This is generated from the causal "
                "specification"
            )
        if not surrogate_models:
            raise ValueError("At least one surrogate model is required to perform a search")
        constructor_overrides, attribute_overrides = self._split_config()

        solutions = []
        for surrogate in surrogate_models:
            fitness_function = self._make_fitness_function(surrogate)
            gene_types, gene_space = self.create_gene_types(surrogate, specification)

            ga_kwargs = {
                "num_generations": 200,
                "num_parents_mating": 4,
                "fitness_func": fitness_function,
                "sol_per_pop": 10,
                # Gene 0 is the treatment value; genes 1.. are the adjustment-set values, in order.
                "num_genes": 1 + len(surrogate.adjustment_set),
                "gene_space": gene_space,
                "gene_type": gene_types,
            }
            # Options the GA constructor accepts must be passed to it: pygad sets up its operators and initial
            # population there, so assigning them afterwards has no effect. Other keys are set as attributes.
            ga_kwargs.update(constructor_overrides)
            ga = GA(**ga_kwargs)
            for name, value in attribute_overrides.items():
                setattr(ga, name, value)
            ga.run()

            solution, fitness, _ = ga.best_solution()
            solution_dict = {surrogate.treatment: solution[0]}
            for idx, adj in enumerate(surrogate.adjustment_set):
                solution_dict[adj] = solution[idx + 1]
            solutions.append((solution_dict, fitness, surrogate))

        # This can be done better with fitness normalisation between edges
        return max(solutions, key=itemgetter(1))

    def _split_config(self) -> tuple[dict, dict]:
        """Split ``self.config`` into pygad ``GA`` constructor arguments and remaining attribute overrides."""
        if not self.config:
            return {}, {}
        accepted = inspect.signature(GA).parameters
        constructor_kwargs = {key: value for key, value in self.config.items() if key in accepted}
        attributes = {key: value for key, value in self.config.items() if key not in accepted}
        return constructor_kwargs, attributes

    def _make_fitness_function(self, surrogate: CubicSplineRegressionEstimator):
        """Build the pygad fitness function that scores candidate solutions against ``surrogate``.

        :raises KeyError: If ``surrogate.expected_relationship`` has no contradiction function.
        """
        contradiction_function = self.contradiction_functions[surrogate.expected_relationship]

        # Unused arguments are required for pygad's fitness function signature
        def fitness_function(ga, solution, idx):  # pylint: disable=unused-argument
            # Note: this mutates the surrogate's control/treatment values on every evaluation.
            surrogate.control_value = solution[0] - self.delta
            surrogate.treatment_value = solution[0] + self.delta
            adjustment_dict = {
                adjustment: solution[i + 1] for i, adjustment in enumerate(surrogate.adjustment_set)
            }
            ate = surrogate.estimate_ate_calculated(adjustment_dict)
            if len(ate) > 1:
                raise ValueError(
                    "Multiple ate values provided but currently only single values supported in this method")
            return contradiction_function(ate[0])

        return fitness_function

    @staticmethod
    def _parse_bound(text: str):
        """Parse a constraint bound: integer strings give ``int``; decimals or ratios such as ``1/2`` give ``float``.

        :raises ValueError: If ``text`` is not a number.
        """
        try:
            return int(text)
        except ValueError:
            return float(Fraction(text))

    @staticmethod
    def create_gene_types(
        surrogate_model: CubicSplineRegressionEstimator, specification: CausalSpecification
    ) -> tuple[list, list]:
        """Generate the gene_types and gene_space for a given fitness function and specification

        Only constraints whose string form is ``"<variable> >= <number>"`` or ``"<variable> <= <number>"`` set a
        bound (``"low"`` and ``"high"`` respectively); all other constraints are ignored.

        :param surrogate_model: Instance of a CubicSplineRegressionEstimator
        :param specification: The Causal Specification (combination of Scenario and Causal Dag)
        :return: ``(gene_types, gene_space)``, each ordered as the treatment followed by the adjustment set.
        """
        variables = [surrogate_model.treatment, *surrogate_model.adjustment_set]
        var_space = {var: {} for var in variables}
        for relationship in list(specification.scenario.constraints):
            rel_split = str(relationship).split(" ")
            if rel_split[0] in var_space:
                if rel_split[1] == ">=":
                    var_space[rel_split[0]]["low"] = GeneticSearchAlgorithm._parse_bound(rel_split[2])
                elif rel_split[1] == "<=":
                    var_space[rel_split[0]]["high"] = GeneticSearchAlgorithm._parse_bound(rel_split[2])

        gene_space = [var_space[var] for var in variables]
        gene_types = [specification.scenario.variables.get(var).datatype for var in variables]
        return gene_types, gene_space