Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mc unittest data for RoM #187

Merged
merged 15 commits into from
Jun 3, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 32 additions & 57 deletions statistical_methods_library/imputation/engine.py
Original file line number Diff line number Diff line change
@@ -158,6 +158,8 @@ def impute(
weight_periodicity_multiplier: Multiplied by the periodicity of the
dataset to calculate the previous period when finding the previous
links for weighting.
manual_construction_col: The name of the column containing the
construction value.
ratio_calculator_params: Any extra keyword arguments to the engine are
passed to the specified ratio calculators as keyword args and are
otherwise ignored by this function. Please see the specified ratio
@@ -243,7 +245,6 @@ def impute(
# only if manual_construction_col is not None.
if manual_construction_col:
input_params["manual_const"] = manual_construction_col
fill_values_mc = {}

if back_data_df:
if not isinstance(back_data_df, DataFrame):
@@ -317,31 +318,7 @@ def impute(
prior_period_df = prepared_df.selectExpr(
"min(previous_period) AS prior_period"
).localCheckpoint(eager=False)
if manual_construction_col:
# Set manual construction value as output
# and set marker as MC
mc_df = prepared_df.withColumn(
"marker",
when(
(col("manual_const").isNotNull()) & (col("output").isNull()),
lit(Marker.MANUAL_CONSTRUCTION.value),
).otherwise(col("marker")),
).withColumn(
"output",
when(
(col("manual_const").isNotNull()) & (col("output").isNull()),
col("manual_const"),
).otherwise(col("output")),
)
manual_construction_df = mc_df.filter(
(col("marker") == Marker.MANUAL_CONSTRUCTION.value)
)
# Filter out the MC data so
# it will be not inculded in the link calculations
prepared_df = mc_df.filter(
col("marker").isNull()
| (~(col("marker") == Marker.MANUAL_CONSTRUCTION.value))
)

if back_data_df:
validated_back_data_df = validate_dataframe(
back_data_df, back_input_params, type_mapping, ["ref", "period", "grouping"]
@@ -371,7 +348,7 @@ def impute(
def calculate_ratios():
# This allows us to return early if we have nothing to do
nonlocal prepared_df
nonlocal fill_values_mc

ratio_calculators = []
if "forward" in prepared_df.columns:
prepared_df = (
@@ -471,8 +448,6 @@ def calculate_ratios():

prepared_df = prepared_df.fillna(fill_values)

fill_values_mc = fill_values

if link_filter:
prepared_df = prepared_df.join(
ratio_calculation_df.select(
@@ -576,34 +551,6 @@ def calculate_weighted_link(link_name):

calculate_ratios()

if manual_construction_col:
# populate link, count, default information
# for manual_construction data
# Get the required additional output columns
mc_cols = manual_construction_df.columns
mc_additional_cols = []
for key in output_col_mapping.keys():
# Remove growth_forward and growth_backward
# as it should be null for non responder
if (key not in mc_cols) and (
key not in ["growth_forward", "growth_backward"]
):
mc_additional_cols.append(key)
manual_construction_df = (
manual_construction_df.alias("mc")
.join(
prepared_df.dropDuplicates(["period", "grouping"]),
["period", "grouping"],
"leftouter",
)
.select(
*(f"mc.{name}" for name in mc_cols),
*mc_additional_cols,
)
)
# Fill null additional columns value with default value.
manual_construction_df = manual_construction_df.fillna(fill_values_mc)

# Caching for both imputed and unimputed data.
imputed_df = None
null_response_df = None
@@ -810,6 +757,34 @@ def forward_impute_from_construction(df: DataFrame) -> DataFrame:
df, "forward", Marker.FORWARD_IMPUTE_FROM_CONSTRUCTION, True
)

if manual_construction_col:
# Set manual construction value as output
# and marker as MC
mc_df = prepared_df.withColumn(
"marker",
when(
(col("manual_const").isNotNull()) & (col("output").isNull()),
lit(Marker.MANUAL_CONSTRUCTION.value),
).otherwise(col("marker")),
).withColumn(
"output",
when(
(col("manual_const").isNotNull()) & (col("output").isNull()),
col("manual_const"),
).otherwise(col("output")),
)

# Filter out identifiers with a MC value.So it prevents the FIR from
# being issued against the targeted FIMC. This MC data will be merged with
# the main df prior to the forward_impute_from_manual_construction stage.
manual_construction_df = mc_df.filter(
(col("marker") == Marker.MANUAL_CONSTRUCTION.value)
)
prepared_df = mc_df.filter(
col("marker").isNull()
| (~(col("marker") == Marker.MANUAL_CONSTRUCTION.value))
)

df = prepared_df
for stage in (
forward_impute_from_response,
Original file line number Diff line number Diff line change
@@ -2,4 +2,4 @@ identifier,date,group,question,other,manual_construction
1234,"202105",900,,78,
1235,"202105",100,,81,
1236,"202105",100,2113,81,
1237,"202105",200,,81,3189
1237,"202105",200,,81,3189
Original file line number Diff line number Diff line change
@@ -2,4 +2,4 @@ identifier,date,group,output,marker,forward,backward,construction,count_forward,
1234,202105,900,10,FIMC,1,1,1,0,0,0,true,true,true
1235,202105,100,20,FIMC,1,1,26.08641975,0,0,1,true,true,false
1236,202105,100,2113,R,1,1,26.08641975,0,0,1,true,true,false
1237,202105,200,3189,MC,1,1,1,0,0,0,true,true,true
1237,202105,200,3189,MC,1,1,1,0,0,0,true,true,true
Original file line number Diff line number Diff line change
@@ -10,4 +10,4 @@ identifier,date,group,question,other,manual_construction
30003,202003,100,6492,7,
30004,202001,100,,81,4321
30004,202002,100,2113,81,
30004,202003,100,,81,3189
30004,202003,100,,81,3189
Original file line number Diff line number Diff line change
@@ -10,4 +10,4 @@ identifier,date,group,growth_forward,growth_backward,forward,backward,constructi
30003,202003,100,3.686541738,,1.526947,1,103.0153846,6492,R,3,0,3,false,true,false
30004,202001,100,,,1,2.196577972,194.6,4321,MC,0,3,3,true,false,false
30004,202002,100,,,0.652198238,1.866715325,90.8436019,2113,R,3,3,4,false,false,false
30004,202003,100,,,1.526947,1,103.0153846,3189,MC,3,0,3,false,true,false
30004,202003,100,,,1.526947,1,103.0153846,3189,MC,3,0,3,false,true,false
Original file line number Diff line number Diff line change
@@ -14,4 +14,4 @@ identifier,date,group,growth_forward,growth_backward,forward,backward,constructi
40004,202001,100,,0.521655144,1,2.053506032,128.0208333,5131,R,0,4,4,true,false,false
40004,202002,100,1.916975248,,0.916179196,1.507228985,120.7395833,9836,R,4,3,4,false,false,false
40004,202003,100,,,3.003997558,1.551291583,115.6929824,7525,MC,3,3,3,false,false,false
40004,202004,100,,,0.645360538,1,73.71052632,4856.338052,FIMC,3,0,3,false,true,false
40004,202004,100,,,0.645360538,1,73.71052632,4856.338052,FIMC,3,0,3,false,true,false
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
identifier,date,group,question,other,manual_construction
110001,202001,100,9244,89,
110001,202002,100,8916,89,
110001,202003,100,6194,89,
110002,202001,100,4826,83,
110002,202002,100,5903,83,
110002,202003,100,4743,83,
110003,202001,100,7586,4,
110003,202002,100,1016,4,
110003,202003,100,1429,4,
110004,202001,100,3975,76,
110004,202002,100,3044,76,
110004,202003,100,,76,
110005,202001,200,5217,27,
110005,202002,200,7016,27,
110005,202003,200,9940,27,
110006,202001,200,5325,42,
110006,202002,200,7747,42,
110006,202003,200,6685,42,
110007,202001,200,5496,19,
110007,202002,200,,19,1010
110007,202003,200,,19,
110008,202001,200,,43,
110008,202002,200,3913,43,
110008,202003,200,6013,43,
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
identifier,date,group,output,marker,unweighted_forward,unweighted_backward,unweighted_construction,forward,backward,construction,count_forward,count_backward,count_construction,default_forward,default_backward,default_construction
110001,202001,100,9244,R,1,1.357646062,101.710317,1,1.442523031,76.03985873,0,4,4,TRUE,FALSE,FALSE
110002,202001,100,4826,R,1,1.357646062,101.710317,1,1.442523031,76.03985873,0,4,4,TRUE,FALSE,FALSE
110003,202001,100,7586,R,1,1.357646062,101.710317,1,1.442523031,76.03985873,0,4,4,TRUE,FALSE,FALSE
110004,202001,100,3975,R,1,1.357646062,101.710317,1,1.442523031,76.03985873,0,4,4,TRUE,FALSE,FALSE
110005,202001,200,5217,R,1,0.714082504,182.25,1,1.854216752,144.865,0,2,3,TRUE,FALSE,FALSE
110006,202001,200,5325,R,1,0.714082504,182.25,1,1.854216752,144.865,0,2,3,TRUE,FALSE,FALSE
110007,202001,200,5496,R,1,0.714082504,182.25,1,1.854216752,144.865,0,2,3,TRUE,FALSE,FALSE
110008,202001,200,7255.550151,BI,1,0.714082504,182.25,1,1.854216752,144.865,0,2,3,TRUE,FALSE,FALSE
110001,202002,100,8916,R,0.736568998,1.280527252,74.91666667,0.852534499,1.580013626,98.16768333,4,3,4,FALSE,FALSE,FALSE
110002,202002,100,5903,R,0.736568998,1.280527252,74.91666667,0.852534499,1.580013626,98.16768333,4,3,4,FALSE,FALSE,FALSE
110003,202002,100,1016,R,0.736568998,1.280527252,74.91666667,0.852534499,1.580013626,98.16768333,4,3,4,FALSE,FALSE,FALSE
110004,202002,100,3044,R,0.736568998,1.280527252,74.91666667,0.852534499,1.580013626,98.16768333,4,3,4,FALSE,FALSE,FALSE
110005,202002,200,7016,R,1.400398406,0.824984539,166.75,1.389168203,0.74779677,110.7264285,2,3,3,FALSE,FALSE,FALSE
110006,202002,200,7747,R,1.400398406,0.824984539,166.75,1.389168203,0.74779677,110.7264285,2,3,3,FALSE,FALSE,FALSE
110007,202002,200,1010,MC,1.400398406,0.824984539,166.75,1.389168203,0.74779677,110.7264285,2,3,3,FALSE,FALSE,FALSE
110008,202002,200,3913,R,1.400398406,0.824984539,166.75,1.389168203,0.74779677,110.7264285,2,3,3,FALSE,FALSE,FALSE
110001,202003,100,6194,R,0.780928323,1,70.26136364,0.946964162,1,79.81873182,3,0,3,FALSE,TRUE,FALSE
110002,202003,100,4743,R,0.780928323,1,70.26136364,0.946964162,1,79.81873182,3,0,3,FALSE,TRUE,FALSE
110003,202003,100,1429,R,0.780928323,1,70.26136364,0.946964162,1,79.81873182,3,0,3,FALSE,TRUE,FALSE
110004,202003,100,2882.558909,FIR,0.780928323,1,70.26136364,0.946964162,1,79.81873182,3,0,3,FALSE,TRUE,FALSE
110005,202003,200,9940,R,1.212143928,1,202.125,1.082364964,1,150.1950965,3,0,3,FALSE,TRUE,FALSE
110006,202003,200,6685,R,1.212143928,1,202.125,1.082364964,1,150.1950965,3,0,3,FALSE,TRUE,FALSE
110007,202003,200,1093.188614,FIMC,1.212143928,1,202.125,1.082364964,1,150.1950965,3,0,3,FALSE,TRUE,FALSE
110008,202003,200,6013,R,1.212143928,1,202.125,1.082364964,1,150.1950965,3,0,3,FALSE,TRUE,FALSE
110001,201901,100,9244,R,1,1.5274,50.3694,,,,,,,,,
110001,201902,100,8916,R,0.9685,1.8795,121.4187,,,,,,,,,
110001,201903,100,6194,R,1.113,1,89.3761,,,,,,,,,
110002,201901,100,4826,R,1,1.5274,50.3694,,,,,,,,,
110002,201902,100,5903,R,0.9685,1.8795,121.4187,,,,,,,,,
110002,201903,100,4743,R,1.113,1,89.3761,,,,,,,,,
110003,201901,100,7586,R,1,1.5274,50.3694,,,,,,,,,
110003,201902,100,1016,R,0.9685,1.8795,121.4187,,,,,,,,,
110003,201903,100,1429,R,1.113,1,89.3761,,,,,,,,,
110004,201901,100,3975,R,1,1.5274,50.3694,,,,,,,,,
110004,201902,100,3044,R,0.9685,1.8795,121.4187,,,,,,,,,
110004,201903,100,8437,R,1.113,1,89.3761,,,,,,,,,
110005,201901,200,5217,R,1,2.994350985,107.48,,,,,,,,,
110005,201902,200,7016,R,1.377938182,0.670608989,54.70285714,,,,,,,,,
110005,201903,200,9940,R,0.952585628,1,98.26519337,,,,,,,,,
110006,201901,200,5325,R,1,2.994350985,107.48,,,,,,,,,
110006,201902,200,7747,R,1.377938182,0.670608989,54.70285714,,,,,,,,,
110006,201903,200,6685,R,0.952585628,1,98.26519337,,,,,,,,,
110007,201901,200,5496,R,1,2.994350985,107.48,,,,,,,,,
110007,201902,200,1010,R,1.377938182,0.670608989,54.70285714,,,,,,,,,
110007,201903,200,1235,R,0.952585628,1,98.26519337,,,,,,,,,
110008,201901,200,8272,R,1,2.994350985,107.48,,,,,,,,,
110008,201902,200,3913,R,1.377938182,0.670608989,54.70285714,,,,,,,,,
110008,201903,200,6013,R,0.952585628,1,98.26519337,,,,,,,,,
5 changes: 5 additions & 0 deletions tests/imputation/ratio_of_means.toml
Original file line number Diff line number Diff line change
@@ -21,3 +21,8 @@ link_filter = "identifier != '10003'"

[scenarios.35_BI_BI_R_FI_FI_R_FI_alternating_filtered]
link_filter = "NOT(identifier = '70001' AND date IN ('202003', '202005'))"

[scenarios.36_R_MC_FIMC_weighted]
weight = "0.5"
weight_periodicity_multiplier = 12
manual_construction_col = "manual_construction"
2 changes: 1 addition & 1 deletion tests/imputation/test_engine.py
Original file line number Diff line number Diff line change
@@ -425,7 +425,7 @@ def test_input_data_contains_nulls(fxt_load_test_csv, fxt_spark_session):
impute(input_df=test_dataframe, **params)


def test_back_data_fimc(fxt_load_test_csv, fxt_spark_session):
def test_back_data_mc_fimc(fxt_load_test_csv, fxt_spark_session):
test_dataframe = fxt_load_test_csv(
dataframe_columns,
dataframe_types,