Skip to content

Commit 0f71731

Browse files
committed
Refactor how Payload headers are handled
This change solves three separate but tightly coupled issues: 1. `Payload.content_type` may conflict with `Payload.headers[CONTENT_TYPE]`. Although the former ultimately takes priority, it is strange for a Payload object to hold two possibly different states describing its content type. 2. IOPayload now respects a Content-Disposition that is passed in via headers. 3. `ClientRequest.skip_auto_headers` now filters `Payload.headers` as well. This issue was discovered during the refactoring: a Payload object may set up some auto headers, and those would bypass the skip logic.
1 parent 0715ae7 commit 0f71731

File tree

4 files changed

+136
-68
lines changed

4 files changed

+136
-68
lines changed

aiohttp/client_reqrep.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -495,16 +495,14 @@ def update_body_from_data(self, body: Any) -> None:
495495
if hdrs.CONTENT_LENGTH not in self.headers:
496496
self.headers[hdrs.CONTENT_LENGTH] = str(size)
497497

498-
# set content-type
499-
if (hdrs.CONTENT_TYPE not in self.headers and
500-
hdrs.CONTENT_TYPE not in self.skip_auto_headers):
501-
self.headers[hdrs.CONTENT_TYPE] = body.content_type
502-
503498
# copy payload headers
504-
if body.headers:
505-
for (key, value) in body.headers.items():
506-
if key not in self.headers:
507-
self.headers[key] = value
499+
assert body.headers
500+
for (key, value) in body.headers.items():
501+
if key in self.headers:
502+
continue
503+
if key in self.skip_auto_headers:
504+
continue
505+
self.headers[key] = value
508506

509507
def update_expect_continue(self, expect: bool=False) -> None:
510508
if expect:

aiohttp/multipart.py

+5-18
Original file line numberDiff line numberDiff line change
@@ -702,9 +702,6 @@ def __init__(self, subtype: str='mixed',
702702
super().__init__(None, content_type=ctype)
703703

704704
self._parts = [] # type: List[_Part] # noqa
705-
self._headers = CIMultiDict() # type: CIMultiDict[str]
706-
assert self.content_type is not None
707-
self._headers[CONTENT_TYPE] = self.content_type
708705

709706
def __enter__(self) -> 'MultipartWriter':
710707
return self
@@ -769,28 +766,18 @@ def append(
769766
headers = CIMultiDict()
770767

771768
if isinstance(obj, Payload):
772-
if obj.headers is not None:
773-
obj.headers.update(headers)
774-
else:
775-
if isinstance(headers, CIMultiDict):
776-
obj._headers = headers
777-
else:
778-
obj._headers = CIMultiDict(headers)
769+
obj.headers.update(headers)
779770
return self.append_payload(obj)
780771
else:
781772
try:
782-
return self.append_payload(get_payload(obj, headers=headers))
773+
payload = get_payload(obj, headers=headers)
783774
except LookupError:
784-
raise TypeError
775+
raise TypeError('Cannot create payload from %r' % obj)
776+
else:
777+
return self.append_payload(payload)
785778

786779
def append_payload(self, payload: Payload) -> Payload:
787780
"""Adds a new body part to multipart writer."""
788-
# content-type
789-
assert payload.headers is not None
790-
if CONTENT_TYPE not in payload.headers:
791-
assert payload.content_type is not None
792-
payload.headers[CONTENT_TYPE] = payload.content_type
793-
794781
# compression
795782
encoding = payload.headers.get(CONTENT_ENCODING, '').lower() # type: Optional[str] # noqa
796783
if encoding and encoding not in ('deflate', 'gzip', 'identity'):

aiohttp/payload.py

+25-31
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,13 @@
77
import warnings
88
from abc import ABC, abstractmethod
99
from itertools import chain
10-
from typing import ( # noqa
10+
from typing import (
1111
IO,
1212
TYPE_CHECKING,
1313
Any,
1414
ByteString,
15-
Callable,
1615
Dict,
1716
Iterable,
18-
List,
1917
Optional,
2018
Text,
2119
TextIO,
@@ -47,6 +45,10 @@
4745
TOO_LARGE_BYTES_BODY = 2 ** 20 # 1 MB
4846

4947

48+
if TYPE_CHECKING: # pragma: no cover
49+
from typing import List # noqa
50+
51+
5052
class LookupError(Exception):
5153
pass
5254

@@ -120,9 +122,8 @@ def register(self,
120122

121123
class Payload(ABC):
122124

125+
_default_content_type = 'application/octet-stream' # type: str
123126
_size = None # type: Optional[float]
124-
_headers = None # type: Optional[_CIMultiDict]
125-
_content_type = 'application/octet-stream' # type: Optional[str]
126127

127128
def __init__(self,
128129
value: Any,
@@ -137,18 +138,20 @@ def __init__(self,
137138
filename: Optional[str]=None,
138139
encoding: Optional[str]=None,
139140
**kwargs: Any) -> None:
140-
self._value = value
141141
self._encoding = encoding
142142
self._filename = filename
143-
if headers is not None:
144-
self._headers = CIMultiDict(headers)
145-
if content_type is sentinel and hdrs.CONTENT_TYPE in self._headers:
146-
content_type = self._headers[hdrs.CONTENT_TYPE]
147-
148-
if content_type is sentinel:
149-
content_type = None
150-
151-
self._content_type = content_type
143+
self._headers = CIMultiDict() # type: _CIMultiDict
144+
self._value = value
145+
if content_type is not sentinel and content_type is not None:
146+
self._headers[hdrs.CONTENT_TYPE] = content_type
147+
elif self._filename is not None:
148+
content_type = mimetypes.guess_type(self._filename)[0]
149+
if content_type is None:
150+
content_type = self._default_content_type
151+
self._headers[hdrs.CONTENT_TYPE] = content_type
152+
else:
153+
self._headers[hdrs.CONTENT_TYPE] = self._default_content_type
154+
self._headers.update(headers or {})
152155

153156
@property
154157
def size(self) -> Optional[float]:
@@ -161,15 +164,12 @@ def filename(self) -> Optional[str]:
161164
return self._filename
162165

163166
@property
164-
def headers(self) -> Optional[_CIMultiDict]:
167+
def headers(self) -> _CIMultiDict:
165168
"""Custom item headers"""
166169
return self._headers
167170

168171
@property
169172
def _binary_headers(self) -> bytes:
170-
if self.headers is None:
171-
# FIXME: This case actually is unreachable.
172-
return b'' # pragma: no cover
173173
return ''.join(
174174
[k + ': ' + v + '\r\n' for k, v in self.headers.items()]
175175
).encode('utf-8') + b'\r\n'
@@ -180,24 +180,15 @@ def encoding(self) -> Optional[str]:
180180
return self._encoding
181181

182182
@property
183-
def content_type(self) -> Optional[str]:
183+
def content_type(self) -> str:
184184
"""Content type"""
185-
if self._content_type is not None:
186-
return self._content_type
187-
elif self._filename is not None:
188-
mime = mimetypes.guess_type(self._filename)[0]
189-
return 'application/octet-stream' if mime is None else mime
190-
else:
191-
return Payload._content_type
185+
return self._headers[hdrs.CONTENT_TYPE]
192186

193187
def set_content_disposition(self,
194188
disptype: str,
195189
quote_fields: bool=True,
196190
**params: Any) -> None:
197191
"""Sets ``Content-Disposition`` header."""
198-
if self._headers is None:
199-
self._headers = CIMultiDict()
200-
201192
self._headers[hdrs.CONTENT_DISPOSITION] = content_disposition_header(
202193
disptype, quote_fields=quote_fields, **params)
203194

@@ -292,7 +283,10 @@ def __init__(self,
292283
super().__init__(value, *args, **kwargs)
293284

294285
if self._filename is not None and disposition is not None:
295-
self.set_content_disposition(disposition, filename=self._filename)
286+
if hdrs.CONTENT_DISPOSITION not in self.headers:
287+
self.set_content_disposition(
288+
disposition, filename=self._filename
289+
)
296290

297291
async def write(self, writer: AbstractStreamWriter) -> None:
298292
loop = asyncio.get_event_loop()

tests/test_multipart.py

+99-10
Original file line numberDiff line numberDiff line change
@@ -854,8 +854,8 @@ async def test_writer_serialize_with_content_encoding_gzip(buf, stream,
854854
await writer.write(stream)
855855
headers, message = bytes(buf).split(b'\r\n\r\n', 1)
856856

857-
assert (b'--:\r\nContent-Encoding: gzip\r\n'
858-
b'Content-Type: text/plain; charset=utf-8' == headers)
857+
assert (b'--:\r\nContent-Type: text/plain; charset=utf-8\r\n'
858+
b'Content-Encoding: gzip' == headers)
859859

860860
decompressor = zlib.decompressobj(wbits=16+zlib.MAX_WBITS)
861861
data = decompressor.decompress(message.split(b'\r\n')[0])
@@ -869,8 +869,8 @@ async def test_writer_serialize_with_content_encoding_deflate(buf, stream,
869869
await writer.write(stream)
870870
headers, message = bytes(buf).split(b'\r\n\r\n', 1)
871871

872-
assert (b'--:\r\nContent-Encoding: deflate\r\n'
873-
b'Content-Type: text/plain; charset=utf-8' == headers)
872+
assert (b'--:\r\nContent-Type: text/plain; charset=utf-8\r\n'
873+
b'Content-Encoding: deflate' == headers)
874874

875875
thing = b'\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00\r\n--:--\r\n'
876876
assert thing == message
@@ -883,8 +883,8 @@ async def test_writer_serialize_with_content_encoding_identity(buf, stream,
883883
await writer.write(stream)
884884
headers, message = bytes(buf).split(b'\r\n\r\n', 1)
885885

886-
assert (b'--:\r\nContent-Encoding: identity\r\n'
887-
b'Content-Type: application/octet-stream\r\n'
886+
assert (b'--:\r\nContent-Type: application/octet-stream\r\n'
887+
b'Content-Encoding: identity\r\n'
888888
b'Content-Length: 16' == headers)
889889

890890
assert thing == message.split(b'\r\n')[0]
@@ -902,8 +902,8 @@ async def test_writer_with_content_transfer_encoding_base64(buf, stream,
902902
await writer.write(stream)
903903
headers, message = bytes(buf).split(b'\r\n\r\n', 1)
904904

905-
assert (b'--:\r\nContent-Transfer-Encoding: base64\r\n'
906-
b'Content-Type: text/plain; charset=utf-8' ==
905+
assert (b'--:\r\nContent-Type: text/plain; charset=utf-8\r\n'
906+
b'Content-Transfer-Encoding: base64' ==
907907
headers)
908908

909909
assert b'VGltZSB0byBSZWxheCE=' == message.split(b'\r\n')[0]
@@ -916,8 +916,8 @@ async def test_writer_content_transfer_encoding_quote_printable(buf, stream,
916916
await writer.write(stream)
917917
headers, message = bytes(buf).split(b'\r\n\r\n', 1)
918918

919-
assert (b'--:\r\nContent-Transfer-Encoding: quoted-printable\r\n'
920-
b'Content-Type: text/plain; charset=utf-8' == headers)
919+
assert (b'--:\r\nContent-Type: text/plain; charset=utf-8\r\n'
920+
b'Content-Transfer-Encoding: quoted-printable' == headers)
921921

922922
assert (b'=D0=9F=D1=80=D0=B8=D0=B2=D0=B5=D1=82,'
923923
b' =D0=BC=D0=B8=D1=80!' == message.split(b'\r\n')[0])
@@ -1049,6 +1049,95 @@ async def test_write_preserves_content_disposition(
10491049
)
10501050
assert message == b'foo\r\n--:--\r\n'
10511051

1052+
async def test_preserve_content_disposition_header(self, buf, stream):
1053+
"""
1054+
https://github.com/aio-libs/aiohttp/pull/3475#issuecomment-451072381
1055+
"""
1056+
with open(__file__, 'rb') as fobj:
1057+
with aiohttp.MultipartWriter('form-data', boundary=':') as writer:
1058+
part = writer.append(
1059+
fobj,
1060+
headers={
1061+
CONTENT_DISPOSITION: 'attachments; filename="bug.py"',
1062+
CONTENT_TYPE: 'text/python',
1063+
}
1064+
)
1065+
content_length = part.size
1066+
await writer.write(stream)
1067+
1068+
assert part.headers[CONTENT_TYPE] == 'text/python'
1069+
assert part.headers[CONTENT_DISPOSITION] == (
1070+
'attachments; filename="bug.py"'
1071+
)
1072+
1073+
headers, _ = bytes(buf).split(b'\r\n\r\n', 1)
1074+
1075+
assert headers == (
1076+
b'--:\r\n'
1077+
b'Content-Type: text/python\r\n'
1078+
b'Content-Disposition: attachments; filename="bug.py"\r\n'
1079+
b'Content-Length: %s'
1080+
b'' % (str(content_length).encode(),)
1081+
)
1082+
1083+
async def test_set_content_disposition_override(self, buf, stream):
1084+
"""
1085+
https://github.com/aio-libs/aiohttp/pull/3475#issuecomment-451072381
1086+
"""
1087+
with open(__file__, 'rb') as fobj:
1088+
with aiohttp.MultipartWriter('form-data', boundary=':') as writer:
1089+
part = writer.append(
1090+
fobj,
1091+
headers={
1092+
CONTENT_DISPOSITION: 'attachments; filename="bug.py"',
1093+
CONTENT_TYPE: 'text/python',
1094+
}
1095+
)
1096+
content_length = part.size
1097+
await writer.write(stream)
1098+
1099+
assert part.headers[CONTENT_TYPE] == 'text/python'
1100+
assert part.headers[CONTENT_DISPOSITION] == (
1101+
'attachments; filename="bug.py"'
1102+
)
1103+
1104+
headers, _ = bytes(buf).split(b'\r\n\r\n', 1)
1105+
1106+
assert headers == (
1107+
b'--:\r\n'
1108+
b'Content-Type: text/python\r\n'
1109+
b'Content-Disposition: attachments; filename="bug.py"\r\n'
1110+
b'Content-Length: %s'
1111+
b'' % (str(content_length).encode(),)
1112+
)
1113+
1114+
async def test_reset_content_disposition_header(self, buf, stream):
1115+
"""
1116+
https://github.com/aio-libs/aiohttp/pull/3475#issuecomment-451072381
1117+
"""
1118+
with open(__file__, 'rb') as fobj:
1119+
with aiohttp.MultipartWriter('form-data', boundary=':') as writer:
1120+
part = writer.append(fobj)
1121+
1122+
content_length = part.size
1123+
1124+
assert CONTENT_DISPOSITION in part.headers
1125+
1126+
part.set_content_disposition('attachments', filename='bug.py')
1127+
1128+
await writer.write(stream)
1129+
1130+
headers, _ = bytes(buf).split(b'\r\n\r\n', 1)
1131+
1132+
assert headers == (
1133+
b'--:\r\n'
1134+
b'Content-Type: text/x-python\r\n'
1135+
b'Content-Disposition:'
1136+
b' attachments; filename="bug.py"; filename*=utf-8\'\'bug.py\r\n'
1137+
b'Content-Length: %s'
1138+
b'' % (str(content_length).encode(),)
1139+
)
1140+
10521141

10531142
async def test_async_for_reader() -> None:
10541143
data = [

0 commit comments

Comments
 (0)