Skip to content

Commit 44eeb9a

Browse files
authoredOct 12, 2021
Merge pull request #44 from ONSdigital/more-use-of-filter-back-data-in-imputation
Use filter_back_data when calculating ratios
2 parents 08dac61 + ff11cd3 commit 44eeb9a

File tree

2 files changed

+20
-19
lines changed

2 files changed

+20
-19
lines changed
 

‎pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "statistical_methods_library"
3-
version = "3.5.2"
3+
version = "3.5.3"
44
description = ""
55
authors = ["Your Name <you@example.com>"]
66
license = "MIT"

‎statistical_methods_library/imputation.py

+19-18
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,9 @@ def validate_df(
282282
# Cache the prepared back data df since we'll need a few differently
283283
# filtered versions
284284
prepared_back_data_df = None
285+
# Store the value for the period prior to the start of imputation.
286+
# Stored as a value to avoid a join in output creation.
287+
prior_period = None
285288

286289
# --- Prepare DF ---
287290

@@ -300,15 +303,23 @@ def prepare(df: DataFrame) -> DataFrame:
300303
.withColumn("next_period", calculate_next_period(col("period")))
301304
)
302305

306+
nonlocal prior_period
307+
# We know this will be a single value so use collect as then we
308+
# can filter directly.
309+
prior_period = prepared_df.selectExpr("min(previous_period)").collect()[0][
310+
0
311+
]
312+
303313
nonlocal prepared_back_data_df
304314
if back_data_df:
305315
prepared_back_data_df = (
306316
select_cols(
307-
back_data_df.join(
308-
prepared_df.selectExpr("min(previous_period)"),
309-
[col(period_col) == col("min(previous_period)")],
310-
"inner",
311-
).filter(col(marker_col) != lit(Marker.BACKWARD_IMPUTE.value))
317+
back_data_df.filter(
318+
(
319+
(col(period_col) == lit(prior_period))
320+
& (col(marker_col) != lit(Marker.BACKWARD_IMPUTE.value))
321+
)
322+
)
312323
)
313324
.drop("target")
314325
.withColumn(
@@ -322,12 +333,9 @@ def prepare(df: DataFrame) -> DataFrame:
322333
prepared_back_data_df = prepared_df.filter(col(period_col).isNull())
323334

324335
prepared_back_data_df = prepared_back_data_df.localCheckpoint(eager=True)
325-
326336
# Ratio calculation needs all the responses from the back data
327337
prepared_df = prepared_df.unionByName(
328-
prepared_back_data_df.filter(
329-
col("marker") == lit(Marker.RESPONSE.value)
330-
)
338+
filter_back_data(col("marker") == lit(Marker.RESPONSE.value))
331339
)
332340

333341
return calculate_ratios(prepared_df)
@@ -613,7 +621,7 @@ def forward_impute_from_construction(df: DataFrame) -> DataFrame:
613621
# --- Utility functions ---
614622
def create_output(df: DataFrame) -> DataFrame:
615623
return select_cols(
616-
df.join(prepared_back_data_df, ["period"], "leftanti"), reversed=False
624+
df.filter(col("period") != lit(prior_period)), reversed=False
617625
).withColumnRenamed("output", output_col)
618626

619627
def select_cols(df: DataFrame, reversed: bool = True) -> DataFrame:
@@ -641,14 +649,7 @@ def calculate_next_period(period: Column) -> Column:
641649
).otherwise((period.cast("int") + 1).cast("string"))
642650

643651
def filter_back_data(filter_col: Column) -> DataFrame:
644-
nonlocal prepared_back_data_df
645-
filtered_df = prepared_back_data_df.filter(filter_col).localCheckpoint(
646-
eager=True
647-
)
648-
prepared_back_data_df = prepared_back_data_df.join(
649-
filtered_df, ["ref", "period"], "leftanti"
650-
).localCheckpoint(eager=True)
651-
return filtered_df
652+
return prepared_back_data_df.filter(filter_col).localCheckpoint(eager=True)
652653

653654
# ----------
654655

0 commit comments

Comments
 (0)
Please sign in to comment.