@@ -282,6 +282,9 @@ def validate_df(
282
282
# Cache the prepared back data df since we'll need a few differently
283
283
# filtered versions
284
284
prepared_back_data_df = None
285
+ # Store the value for the period prior to the start of imputation.
286
+ # Stored as a value to avoid a join in output creation.
287
+ prior_period = None
285
288
286
289
# --- Prepare DF ---
287
290
@@ -300,15 +303,23 @@ def prepare(df: DataFrame) -> DataFrame:
300
303
.withColumn ("next_period" , calculate_next_period (col ("period" )))
301
304
)
302
305
306
+ nonlocal prior_period
307
+ # We know this will be a single value so use collect as then we
308
+ # can filter directly.
309
+ prior_period = prepared_df .selectExpr ("min(previous_period)" ).collect ()[0 ][
310
+ 0
311
+ ]
312
+
303
313
nonlocal prepared_back_data_df
304
314
if back_data_df :
305
315
prepared_back_data_df = (
306
316
select_cols (
307
- back_data_df .join (
308
- prepared_df .selectExpr ("min(previous_period)" ),
309
- [col (period_col ) == col ("min(previous_period)" )],
310
- "inner" ,
311
- ).filter (col (marker_col ) != lit (Marker .BACKWARD_IMPUTE .value ))
317
+ back_data_df .filter (
318
+ (
319
+ (col (period_col ) == lit (prior_period ))
320
+ & (col (marker_col ) != lit (Marker .BACKWARD_IMPUTE .value ))
321
+ )
322
+ )
312
323
)
313
324
.drop ("target" )
314
325
.withColumn (
@@ -322,12 +333,9 @@ def prepare(df: DataFrame) -> DataFrame:
322
333
prepared_back_data_df = prepared_df .filter (col (period_col ).isNull ())
323
334
324
335
prepared_back_data_df = prepared_back_data_df .localCheckpoint (eager = True )
325
-
326
336
# Ratio calculation needs all the responses from the back data
327
337
prepared_df = prepared_df .unionByName (
328
- prepared_back_data_df .filter (
329
- col ("marker" ) == lit (Marker .RESPONSE .value )
330
- )
338
+ filter_back_data (col ("marker" ) == lit (Marker .RESPONSE .value ))
331
339
)
332
340
333
341
return calculate_ratios (prepared_df )
@@ -613,7 +621,7 @@ def forward_impute_from_construction(df: DataFrame) -> DataFrame:
613
621
# --- Utility functions ---
614
622
def create_output(df: DataFrame) -> DataFrame:
    """Build the output frame, leaving out rows for the prior period.

    Rows whose "period" equals the enclosing scope's ``prior_period`` are
    removed, the result is passed through ``select_cols(..., reversed=False)``
    and the "output" column is renamed to ``output_col``.

    Args:
        df: Frame holding at least the "period" and "output" columns.

    Returns:
        The filtered frame with "output" renamed to ``output_col``.
    """
    result_df = df
    # prior_period starts as None and is only assigned inside prepare().
    # Comparing a column against a NULL literal evaluates to NULL for every
    # row, and filter() discards NULL rows, so applying the filter in that
    # state would silently empty the output. Skip it when there is no prior
    # period to exclude.
    if prior_period is not None:
        result_df = result_df.filter(col("period") != lit(prior_period))

    # NOTE(review): reversed=False presumably maps the internal column names
    # back to the caller's names - confirm against select_cols.
    return select_cols(result_df, reversed=False).withColumnRenamed(
        "output", output_col
    )
618
626
619
627
def select_cols (df : DataFrame , reversed : bool = True ) -> DataFrame :
@@ -641,14 +649,7 @@ def calculate_next_period(period: Column) -> Column:
641
649
).otherwise ((period .cast ("int" ) + 1 ).cast ("string" ))
642
650
643
651
def filter_back_data(filter_col: Column) -> DataFrame:
    """Return the rows of the prepared back data that satisfy ``filter_col``.

    The cached ``prepared_back_data_df`` from the enclosing scope is only
    read, never reassigned. The filtered result is eagerly local-checkpointed
    before it is returned.

    Args:
        filter_col: Boolean column expression used to select rows.

    Returns:
        The checkpointed subset of the prepared back data.
    """
    matching_rows = prepared_back_data_df.filter(filter_col)
    return matching_rows.localCheckpoint(eager=True)
652
653
653
654
# ----------
654
655
0 commit comments