Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add max_part_size parameter to MultiPartParser #2815

Merged
merged 11 commits into from
Dec 28, 2024
7 changes: 4 additions & 3 deletions starlette/formparsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,15 @@ async def parse(self) -> FormData:


class MultiPartParser:
max_file_size = 1024 * 1024 # 1MB
max_part_size = 1024 * 1024 # 1MB

def __init__(
self,
headers: Headers,
stream: typing.AsyncGenerator[bytes, None],
*,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_file_size: int = 1024 * 1024, # 1MB
max_part_size: int | float = 1024 * 1024, # 1MB
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we remove defaults here and keep the default at top most function?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is okay.

) -> None:
assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
self.headers = headers
Expand All @@ -148,6 +147,8 @@ def __init__(
self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = []
self._file_parts_to_finish: list[MultipartPart] = []
self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = []
self.max_file_size = max_file_size
self.max_part_size = max_part_size

def on_part_begin(self) -> None:
self._current_part = MultipartPart()
Expand Down
24 changes: 21 additions & 3 deletions starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,14 @@ async def json(self) -> typing.Any:
self._json = json.loads(body)
return self._json

async def _get_form(self, *, max_files: int | float = 1000, max_fields: int | float = 1000) -> FormData:
async def _get_form(
self,
*,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_file_size: int = 1024 * 1024,
max_part_size: int | float = 1024 * 1024,
) -> FormData:
if self._form is None:
assert (
parse_options_header is not None
Expand All @@ -264,6 +271,8 @@ async def _get_form(self, *, max_files: int | float = 1000, max_fields: int | fl
self.stream(),
max_files=max_files,
max_fields=max_fields,
max_file_size=max_file_size,
max_part_size=max_part_size,
)
self._form = await multipart_parser.parse()
except MultiPartException as exc:
Expand All @@ -278,9 +287,18 @@ async def _get_form(self, *, max_files: int | float = 1000, max_fields: int | fl
return self._form

def form(
self, *, max_files: int | float = 1000, max_fields: int | float = 1000
self,
*,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_file_size: int = 1024 * 1024,
max_part_size: int | float = 1024 * 1024,
) -> AwaitableOrContextManager[FormData]:
return AwaitableOrContextManagerWrapper(self._get_form(max_files=max_files, max_fields=max_fields))
return AwaitableOrContextManagerWrapper(
self._get_form(
max_files=max_files, max_fields=max_fields, max_file_size=max_file_size, max_part_size=max_part_size
)
)

async def close(self) -> None:
if self._form is not None: # pragma: no branch
Expand Down
48 changes: 46 additions & 2 deletions tests/test_formparsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,14 @@ async def app_read_body(scope: Scope, receive: Receive, send: Send) -> None:
await response(scope, receive, send)


def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000) -> ASGIApp:
def make_app_max_parts(
max_files: int = 1000, max_fields: int = 1000, max_file_size: int = 1024 * 1024, max_part_size: int = 1024 * 1024
) -> ASGIApp:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
data = await request.form(max_files=max_files, max_fields=max_fields)
data = await request.form(
max_files=max_files, max_fields=max_fields, max_file_size=max_file_size, max_part_size=max_part_size
)
output: dict[str, typing.Any] = {}
for key, value in data.items():
if isinstance(value, UploadFile):
Expand Down Expand Up @@ -699,3 +703,43 @@ def test_max_part_size_exceeds_limit(
response = client.post("/", data=multipart_data, headers=headers) # type: ignore
assert response.status_code == 400
assert response.text == "Part exceeded maximum size of 1024KB."


@pytest.mark.parametrize(
"app,expectation",
[
(make_app_max_parts(max_part_size=1024 * 10), pytest.raises(MultiPartException)),
(
Starlette(routes=[Mount("/", app=make_app_max_parts(max_part_size=1024 * 10))]),
does_not_raise(),
),
],
)
def test_max_part_size_exceeds_custom_limit(
app: ASGIApp,
expectation: typing.ContextManager[Exception],
test_client_factory: TestClientFactory,
) -> None:
client = test_client_factory(app)
boundary = "------------------------4K1ON9fZkj9uCUmqLHRbbR"

multipart_data = (
f"--{boundary}\r\n"
f'Content-Disposition: form-data; name="small"\r\n\r\n'
"small content\r\n"
f"--{boundary}\r\n"
f'Content-Disposition: form-data; name="large"\r\n\r\n'
+ ("x" * 1024 * 10 + "x") # 1MB + 1 byte of data
+ "\r\n"
f"--{boundary}--\r\n"
).encode("utf-8")

headers = {
"Content-Type": f"multipart/form-data; boundary={boundary}",
"Transfer-Encoding": "chunked",
}

with expectation:
response = client.post("/", data=multipart_data, headers=headers) # type: ignore
assert response.status_code == 400
assert response.text == "Part exceeded maximum size of 10KB."
Loading