@@ -216,26 +216,18 @@ def estimate(
216
216
raise ValidationError ("The h value must be the same per period and stratum." )
217
217
218
218
# Per period and stratum, the death count (death_marker == 1) must not exceed the sample count (sample_marker == 1); equal counts are allowed.
219
- if death_marker_col is not None :
220
- death_df = (
221
- input_df .filter ((col (death_marker_col ) == 1 ))
222
- .groupBy ([period_col , strata_col ])
223
- .agg (count (col (death_marker_col )))
224
- )
225
- sample_df = (
226
- input_df .filter ((col (sample_marker_col ) == 1 ))
227
- .groupBy ([period_col , strata_col ])
228
- .agg (count (col (sample_marker_col )))
229
- )
230
- if (
231
- death_df .join (sample_df , ["period" , "strata" ], "left" )
232
- .fillna (0 , ["count(sample_inclusion_marker)" ])
233
- .filter (
234
- (col ("count(death_marker)" )) > (col ("count(sample_inclusion_marker)" ))
235
- )
219
+ if (
220
+ death_marker_col is not None
221
+ and (
222
+ input_df .groupBy ([period_col , strata_col ])
223
+ .agg (sum (col (death_marker_col )), sum (col (sample_marker_col )))
224
+ .fillna (0 , ["sum(sample_inclusion_marker)" ])
225
+ .filter (col ("sum(death_marker)" ) > col ("sum(sample_inclusion_marker)" ))
236
226
.count ()
237
- ) >= 1 :
238
- raise ValidationError ("The death count must be less than sample count." )
227
+ )
228
+ >= 1
229
+ ):
230
+ raise ValidationError ("The death count must be less than sample count." )
239
231
240
232
# --- prepare our working data frame ---
241
233
col_list = [
0 commit comments