Skip to content

Commit beb802a

Browse files
authored
chore[python]: Enforce strict type equality in mypy check
1 parent 7e721d0 commit beb802a

12 files changed

+53
-41
lines changed

py-polars/pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ disallow_untyped_calls = true
4141
warn_redundant_casts = true
4242
# warn_return_any = true
4343
no_implicit_reexport = true
44-
# strict_equality = true
44+
strict_equality = true
4545
# TODO: When all flags are enabled, replace by strict = true
4646
enable_error_code = [
4747
"redundant-expr",

py-polars/tests/db-benchmark/various.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# may contain many things that seemed to go wrong at scale
22

33
import time
4+
from typing import cast
45

56
import numpy as np
67

@@ -50,8 +51,8 @@
5051
computed = permuted.select(
5152
[pl.col("id").min().alias("min"), pl.col("id").max().alias("max")]
5253
)
53-
assert computed[0, "min"] == minimum
54-
assert computed[0, "max"] == maximum
54+
assert cast(int, computed[0, "min"]) == minimum
55+
assert cast(float, computed[0, "max"]) == maximum
5556

5657

5758
def test_windows_not_cached() -> None:

py-polars/tests/io/test_csv.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import zlib
88
from datetime import date, datetime, time
99
from pathlib import Path
10+
from typing import cast
1011

1112
import pytest
1213

@@ -402,8 +403,8 @@ def test_csv_globbing(examples_dir: str) -> None:
402403

403404
df = pl.read_csv(path, columns=["category", "sugars_g"])
404405
assert df.shape == (135, 2)
405-
assert df.row(-1) == ("seafood", 1)
406-
assert df.row(0) == ("vegetables", 2)
406+
assert df.row(-1) == ("seafood", 1) # type: ignore[comparison-overlap]
407+
assert df.row(0) == ("vegetables", 2) # type: ignore[comparison-overlap]
407408

408409
with pytest.raises(ValueError):
409410
_ = pl.read_csv(path, dtypes=[pl.Utf8, pl.Int64, pl.Int64, pl.Int64])
@@ -509,7 +510,10 @@ def test_fallback_chrono_parser() -> None:
509510
2021-10-10,2021-10-10
510511
"""
511512
)
512-
assert pl.read_csv(data.encode(), parse_dates=True).null_count().row(0) == (0, 0)
513+
assert cast(
514+
tuple[int, int],
515+
pl.read_csv(data.encode(), parse_dates=True).null_count().row(0),
516+
) == (0, 0)
513517

514518

515519
def test_csv_string_escaping() -> None:

py-polars/tests/test_constructors.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ def test_init_dict() -> None:
2929

3030
# List of empty list/tuple
3131
df = pl.DataFrame({"a": [[]], "b": [()]})
32-
assert df.schema == {"a": pl.List(pl.Float64), "b": pl.List(pl.Float64)}
32+
expected = {"a": pl.List(pl.Float64), "b": pl.List(pl.Float64)}
33+
assert df.schema == expected # type: ignore[comparison-overlap]
3334
assert df.rows() == [([], [])]
3435

3536
# Mixed dtypes

py-polars/tests/test_datatypes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def test_dtype_init_equivalence() -> None:
1414
if inspect.isclass(dtype) and issubclass(dtype, datatypes.DataType)
1515
}
1616
for dtype in all_datatypes:
17-
assert dtype == dtype()
17+
assert dtype == dtype() # type: ignore[comparison-overlap]
1818

1919

2020
def test_dtype_temporal_units() -> None:

py-polars/tests/test_datelike.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from __future__ import annotations
22

33
import io
4-
import typing
54
from datetime import date, datetime, time, timedelta
6-
from typing import TYPE_CHECKING
5+
from typing import TYPE_CHECKING, no_type_check
76

87
import numpy as np
98
import pandas as pd
@@ -154,7 +153,7 @@ def test_datetime_consistency() -> None:
154153
pl.lit(dt).cast(pl.Datetime("ns")).alias("dt_ns"),
155154
]
156155
)
157-
assert ddf.schema == {
156+
assert ddf.schema == { # type: ignore[comparison-overlap]
158157
"date": pl.Datetime("us"),
159158
"dt": pl.Datetime("us"),
160159
"dt_ms": pl.Datetime("ms"),
@@ -886,7 +885,7 @@ def test_agg_logical() -> None:
886885
assert s.min() == dates[0]
887886

888887

889-
@typing.no_type_check
888+
@no_type_check
890889
def test_from_time_arrow() -> None:
891890
times = pa.array([10, 20, 30], type=pa.time32("s"))
892891
times_table = pa.table([times], names=["times"])
@@ -1027,7 +1026,8 @@ def test_datetime_instance_selection() -> None:
10271026
],
10281027
)
10291028
for tu in DTYPE_TEMPORAL_UNITS:
1030-
assert df.select(pl.col([pl.Datetime(tu)])).dtypes == [pl.Datetime(tu)]
1029+
res = df.select(pl.col([pl.Datetime(tu)])).dtypes
1030+
assert res == [pl.Datetime(tu)] # type: ignore[comparison-overlap]
10311031
assert len(df.filter(pl.col(tu) == test_data[tu][0])) == 1
10321032

10331033

py-polars/tests/test_df.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def test_selection() -> None:
9090

9191
# select columns by mask
9292
assert df[:2, :1].shape == (2, 1)
93-
assert df[:2, "a"].shape == (2, 1)
93+
assert df[:2, "a"].shape == (2, 1) # type: ignore[comparison-overlap]
9494

9595
# column selection by string(s) in first dimension
9696
assert df["a"].to_list() == [1, 2, 3]
@@ -117,11 +117,11 @@ def test_selection() -> None:
117117
assert df[[1, 2], [1, 2]].frame_equal(
118118
pl.DataFrame({"b": [2.0, 3.0], "c": ["b", "c"]})
119119
)
120-
assert df[1, 2] == "b"
121-
assert df[1, 1] == 2.0
122-
assert df[2, 0] == 3
120+
assert typing.cast(str, df[1, 2]) == "b"
121+
assert typing.cast(float, df[1, 1]) == 2.0
122+
assert typing.cast(int, df[2, 0]) == 3
123123

124-
assert df[[0, 1], "b"].shape == (2, 1)
124+
assert df[[0, 1], "b"].shape == (2, 1) # type: ignore[comparison-overlap]
125125
assert df[[2], ["a", "b"]].shape == (1, 2)
126126
assert df.to_series(0).name == "a"
127127
assert (df["a"] == df["a"]).sum() == 3
@@ -132,10 +132,10 @@ def test_selection() -> None:
132132
assert df[1, [2]].frame_equal(expect)
133133
expect = pl.DataFrame({"b": [1.0, 3.0]})
134134
assert df[[0, 2], [1]].frame_equal(expect)
135-
assert df[0, "c"] == "a"
136-
assert df[1, "c"] == "b"
137-
assert df[2, "c"] == "c"
138-
assert df[0, "a"] == 1
135+
assert typing.cast(str, df[0, "c"]) == "a"
136+
assert typing.cast(str, df[1, "c"]) == "b"
137+
assert typing.cast(str, df[2, "c"]) == "c"
138+
assert typing.cast(int, df[0, "a"]) == 1
139139

140140
# more slicing
141141
expect = pl.DataFrame({"a": [3, 2, 1], "b": [3.0, 2.0, 1.0], "c": ["c", "b", "a"]})
@@ -766,9 +766,9 @@ def test_df_fold() -> None:
766766

767767
def test_row_tuple() -> None:
768768
df = pl.DataFrame({"a": ["foo", "bar", "2"], "b": [1, 2, 3], "c": [1.0, 2.0, 3.0]})
769-
assert df.row(0) == ("foo", 1, 1.0)
770-
assert df.row(1) == ("bar", 2, 2.0)
771-
assert df.row(-1) == ("2", 3, 3.0)
769+
assert df.row(0) == ("foo", 1, 1.0) # type: ignore[comparison-overlap]
770+
assert df.row(1) == ("bar", 2, 2.0) # type: ignore[comparison-overlap]
771+
assert df.row(-1) == ("2", 3, 3.0) # type: ignore[comparison-overlap]
772772

773773

774774
def test_df_apply() -> None:
@@ -1058,7 +1058,7 @@ def dot_product() -> None:
10581058
df = pl.DataFrame({"a": [1, 2, 3, 4], "b": [2, 2, 2, 2]})
10591059

10601060
assert df["a"].dot(df["b"]) == 20
1061-
assert df.select([pl.col("a").dot("b")])[0, "a"] == 20
1061+
assert typing.cast(int, df.select([pl.col("a").dot("b")])[0, "a"]) == 20
10621062

10631063

10641064
def test_hash_rows() -> None:

py-polars/tests/test_exprs.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from typing import cast
4+
35
import polars as pl
46
from polars.testing import assert_series_equal, verify_series_and_expr_api
57

@@ -82,7 +84,7 @@ def test_count_expr() -> None:
8284

8385
out = df.select(pl.count())
8486
assert out.shape == (1, 1)
85-
assert out[0, 0] == 5
87+
assert cast(int, out[0, 0]) == 5
8688

8789
out = df.groupby("b", maintain_order=True).agg(pl.count())
8890
assert out["b"].to_list() == ["a", "b"]
@@ -274,9 +276,11 @@ def test_regex_in_filter() -> None:
274276
}
275277
)
276278

277-
assert df.filter(
279+
res = df.filter(
278280
pl.fold(acc=False, f=lambda acc, s: acc | s, exprs=(pl.col("^nrs|flt*$") < 3))
279-
).row(0) == (1, "foo", 1.0)
281+
).row(0)
282+
expected = (1, "foo", 1.0)
283+
assert res == expected # type: ignore[comparison-overlap]
280284

281285

282286
def test_arr_contains() -> None:

py-polars/tests/test_interop.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ def test_from_pandas_ns_resolution() -> None:
388388
[pd.Timestamp(year=2021, month=1, day=1, hour=1, second=1, nanosecond=1)],
389389
columns=["date"],
390390
)
391-
assert pl.from_pandas(df)[0, 0] == datetime(2021, 1, 1, 1, 0, 1)
391+
assert cast(datetime, pl.from_pandas(df)[0, 0]) == datetime(2021, 1, 1, 1, 0, 1)
392392

393393

394394
@no_type_check

py-polars/tests/test_lazy.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ def test_is_finite_is_infinite() -> None:
472472

473473
def test_len() -> None:
474474
df = pl.DataFrame({"nrs": [1, 2, 3]})
475-
assert df.select(col("nrs").len())[0, 0] == 3
475+
assert cast(int, df.select(col("nrs").len())[0, 0]) == 3
476476

477477

478478
def test_cum_agg() -> None:
@@ -505,7 +505,7 @@ def test_round() -> None:
505505

506506
def test_dot() -> None:
507507
df = pl.DataFrame({"a": [1.8, 1.2, 3.0], "b": [3.2, 1, 2]})
508-
assert df.select(pl.col("a").dot(pl.col("b")))[0, 0] == 12.96
508+
assert cast(float, df.select(pl.col("a").dot(pl.col("b")))[0, 0]) == 12.96
509509

510510

511511
def test_sort() -> None:
@@ -696,8 +696,8 @@ def test_rolling(fruits_cars: pl.DataFrame) -> None:
696696
]
697697
)
698698

699-
assert out_single_val_variance[0, "std"] == 0.0
700-
assert out_single_val_variance[0, "var"] == 0.0
699+
assert cast(float, out_single_val_variance[0, "std"]) == 0.0
700+
assert cast(float, out_single_val_variance[0, "var"]) == 0.0
701701

702702

703703
def test_rolling_apply() -> None:
@@ -993,16 +993,16 @@ def test_join_suffix() -> None:
993993
def test_str_concat() -> None:
994994
df = pl.DataFrame({"foo": [1, None, 2]})
995995
df = df.select(pl.col("foo").str.concat("-"))
996-
assert df[0, 0] == "1-null-2"
996+
assert cast(str, df[0, 0]) == "1-null-2"
997997

998998

999999
@pytest.mark.parametrize("no_optimization", [False, True])
10001000
def test_collect_all(df: pl.DataFrame, no_optimization: bool) -> None:
10011001
lf1 = df.lazy().select(pl.col("int").sum())
10021002
lf2 = df.lazy().select((pl.col("floats") * 2).sum())
10031003
out = pl.collect_all([lf1, lf2], no_optimization=no_optimization)
1004-
assert out[0][0, 0] == 6
1005-
assert out[1][0, 0] == 12.0
1004+
assert cast(int, out[0][0, 0]) == 6
1005+
assert cast(float, out[1][0, 0]) == 12.0
10061006

10071007

10081008
def test_spearman_corr() -> None:
@@ -1058,8 +1058,10 @@ def test_pearson_corr() -> None:
10581058

10591059

10601060
def test_cov(fruits_cars: pl.DataFrame) -> None:
1061-
assert fruits_cars.select(pl.cov("A", "B"))[0, 0] == -2.5
1062-
assert fruits_cars.select(pl.cov(pl.col("A"), pl.col("B")))[0, 0] == -2.5
1061+
assert cast(float, fruits_cars.select(pl.cov("A", "B"))[0, 0]) == -2.5
1062+
assert (
1063+
cast(float, fruits_cars.select(pl.cov(pl.col("A"), pl.col("B")))[0, 0]) == -2.5
1064+
)
10631065

10641066

10651067
def test_std(fruits_cars: pl.DataFrame) -> None:

py-polars/tests/test_lists.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def test_dtype() -> None:
7878
("dtm", pl.List(pl.Datetime)),
7979
],
8080
)
81-
assert df.schema == {
81+
assert df.schema == { # type: ignore[comparison-overlap]
8282
"i": pl.List(pl.Int8),
8383
"tm": pl.List(pl.Time),
8484
"dt": pl.List(pl.Date),

py-polars/tests/test_queries.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ def test_when_then_edge_cases_3994() -> None:
316316
.groupby(["id"])
317317
.agg(pl.col("type"))
318318
.with_column(
319-
pl.when(pl.col("type").arr.lengths == 0)
319+
pl.when(pl.col("type").arr.lengths() == 0)
320320
.then(pl.lit(None))
321321
.otherwise(pl.col("type"))
322322
.keep_name()

0 commit comments

Comments
 (0)