Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit fc8ec0d

Browse files
committed Jul 1, 2022
refactored code for review comment
1 parent b348cd6 commit fc8ec0d

File tree

2 files changed

+12
-20
lines changed

2 files changed

+12
-20
lines changed
 

‎pyproject.toml

+1 -1
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "statistical_methods_library"
3-
version = "4.4.0"
3+
version = "4.5.0"
44
description = ""
55
authors = ["Your Name <you@example.com>"]
66
license = "MIT"

‎statistical_methods_library/estimation.py

+11 -19
Original file line number | Diff line number | Diff line change
@@ -216,26 +216,18 @@ def estimate(
216216
raise ValidationError("The h value must be the same per period and stratum.")
217217

218218
# death(death_marker=1) count must be less than sample(sample_marker=1)
219-
if death_marker_col is not None:
220-
death_df = (
221-
input_df.filter((col(death_marker_col) == 1))
222-
.groupBy([period_col, strata_col])
223-
.agg(count(col(death_marker_col)))
224-
)
225-
sample_df = (
226-
input_df.filter((col(sample_marker_col) == 1))
227-
.groupBy([period_col, strata_col])
228-
.agg(count(col(sample_marker_col)))
229-
)
230-
if (
231-
death_df.join(sample_df, ["period", "strata"], "left")
232-
.fillna(0, ["count(sample_inclusion_marker)"])
233-
.filter(
234-
(col("count(death_marker)")) > (col("count(sample_inclusion_marker)"))
235-
)
219+
if (
220+
death_marker_col is not None
221+
and (
222+
input_df.groupBy([period_col, strata_col])
223+
.agg(sum(col(death_marker_col)), sum(col(sample_marker_col)))
224+
.fillna(0, ["sum(sample_inclusion_marker)"])
225+
.filter(col("sum(death_marker)") > col("sum(sample_inclusion_marker)"))
236226
.count()
237-
) >= 1:
238-
raise ValidationError("The death count must be less than sample count.")
227+
)
228+
>= 1
229+
):
230+
raise ValidationError("The death count must be less than sample count.")
239231

240232
# --- prepare our working data frame ---
241233
col_list = [

0 commit comments

Comments
 (0)
Please sign in to comment.