|
17 | 17 | import os
|
18 | 18 | import re
|
19 | 19 | from typing import Any, Dict, Optional, Sequence, Tuple, Type
|
20 |
| -from urllib.parse import urlencode |
| 20 | +from urllib.parse import quote, urlencode |
21 | 21 |
|
22 | 22 | from twisted.internet._resolver import HostResolution
|
23 | 23 | from twisted.internet.address import IPv4Address, IPv6Address
|
@@ -69,7 +69,6 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
69 | 69 | "2001:800::/21",
|
70 | 70 | )
|
71 | 71 | config["url_preview_ip_range_whitelist"] = ("1.1.1.1",)
|
72 |
| - config["url_preview_url_blacklist"] = [] |
73 | 72 | config["url_preview_accept_language"] = [
|
74 | 73 | "en-UK",
|
75 | 74 | "en-US;q=0.9",
|
@@ -1123,3 +1122,43 @@ def test_cache_expiry(self) -> None:
|
1123 | 1122 | os.path.exists(path),
|
1124 | 1123 | f"{os.path.relpath(path, self.media_store_path)} was not deleted",
|
1125 | 1124 | )
|
| 1125 | + |
| 1126 | + @unittest.override_config({"url_preview_url_blacklist": [{"port": "*"}]}) |
| 1127 | + def test_blacklist_port(self) -> None: |
| 1128 | + """Tests that blacklisting URLs with a port makes previewing such URLs |
| 1129 | + fail with a 403 error and doesn't impact other previews. |
| 1130 | + """ |
| 1131 | + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] |
| 1132 | + |
| 1133 | + bad_url = quote("http://matrix.org:8888/foo") |
| 1134 | + good_url = quote("http://matrix.org/foo") |
| 1135 | + |
| 1136 | + channel = self.make_request( |
| 1137 | + "GET", |
| 1138 | + "preview_url?url=" + bad_url, |
| 1139 | + shorthand=False, |
| 1140 | + await_result=False, |
| 1141 | + ) |
| 1142 | + self.pump() |
| 1143 | + self.assertEqual(channel.code, 403, channel.result) |
| 1144 | + |
| 1145 | + channel = self.make_request( |
| 1146 | + "GET", |
| 1147 | + "preview_url?url=" + good_url, |
| 1148 | + shorthand=False, |
| 1149 | + await_result=False, |
| 1150 | + ) |
| 1151 | + self.pump() |
| 1152 | + |
| 1153 | + client = self.reactor.tcpClients[0][2].buildProtocol(None) |
| 1154 | + server = AccumulatingProtocol() |
| 1155 | + server.makeConnection(FakeTransport(client, self.reactor)) |
| 1156 | + client.makeConnection(FakeTransport(server, self.reactor)) |
| 1157 | + client.dataReceived( |
| 1158 | + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n" |
| 1159 | + % (len(self.end_content),) |
| 1160 | + + self.end_content |
| 1161 | + ) |
| 1162 | + |
| 1163 | + self.pump() |
| 1164 | + self.assertEqual(channel.code, 200) |
0 commit comments