Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 6dade80

Browse files
authored
Combine the CAS & SAML implementations for required attributes. (#9326)
1 parent 80d6dc9 commit 6dade80

File tree

9 files changed

+245
-77
lines changed

9 files changed

+245
-77
lines changed

changelog.d/9326.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Share the code for handling required attributes between the CAS and SAML handlers.

synapse/config/cas.py

+30-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
from typing import Any, List
17+
18+
from synapse.config.sso import SsoAttributeRequirement
19+
1620
from ._base import Config
21+
from ._util import validate_config
1722

1823

1924
class CasConfig(Config):
@@ -38,12 +43,16 @@ def read_config(self, config, **kwargs):
3843
public_base_url + "_matrix/client/r0/login/cas/ticket"
3944
)
4045
self.cas_displayname_attribute = cas_config.get("displayname_attribute")
41-
self.cas_required_attributes = cas_config.get("required_attributes") or {}
46+
required_attributes = cas_config.get("required_attributes") or {}
47+
self.cas_required_attributes = _parsed_required_attributes_def(
48+
required_attributes
49+
)
50+
4251
else:
4352
self.cas_server_url = None
4453
self.cas_service_url = None
4554
self.cas_displayname_attribute = None
46-
self.cas_required_attributes = {}
55+
self.cas_required_attributes = []
4756

4857
def generate_config_section(self, config_dir_path, server_name, **kwargs):
4958
return """\
@@ -75,3 +84,22 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs):
7584
# userGroup: "staff"
7685
# department: None
7786
"""
87+
88+
89+
# CAS uses a legacy required attributes mapping, not the one provided by
90+
# SsoAttributeRequirement.
91+
REQUIRED_ATTRIBUTES_SCHEMA = {
92+
"type": "object",
93+
"additionalProperties": {"anyOf": [{"type": "string"}, {"type": "null"}]},
94+
}
95+
96+
97+
def _parsed_required_attributes_def(
98+
required_attributes: Any,
99+
) -> List[SsoAttributeRequirement]:
100+
validate_config(
101+
REQUIRED_ATTRIBUTES_SCHEMA,
102+
required_attributes,
103+
config_path=("cas_config", "required_attributes"),
104+
)
105+
return [SsoAttributeRequirement(k, v) for k, v in required_attributes.items()]

synapse/config/saml2_config.py

+5-20
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717
import logging
1818
from typing import Any, List
1919

20-
import attr
21-
20+
from synapse.config.sso import SsoAttributeRequirement
2221
from synapse.python_dependencies import DependencyException, check_requirements
2322
from synapse.util.module_loader import load_module, load_python_module
2423

@@ -396,32 +395,18 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs):
396395
}
397396

398397

399-
@attr.s(frozen=True)
400-
class SamlAttributeRequirement:
401-
"""Object describing a single requirement for SAML attributes."""
402-
403-
attribute = attr.ib(type=str)
404-
value = attr.ib(type=str)
405-
406-
JSON_SCHEMA = {
407-
"type": "object",
408-
"properties": {"attribute": {"type": "string"}, "value": {"type": "string"}},
409-
"required": ["attribute", "value"],
410-
}
411-
412-
413398
ATTRIBUTE_REQUIREMENTS_SCHEMA = {
414399
"type": "array",
415-
"items": SamlAttributeRequirement.JSON_SCHEMA,
400+
"items": SsoAttributeRequirement.JSON_SCHEMA,
416401
}
417402

418403

419404
def _parse_attribute_requirements_def(
420405
attribute_requirements: Any,
421-
) -> List[SamlAttributeRequirement]:
406+
) -> List[SsoAttributeRequirement]:
422407
validate_config(
423408
ATTRIBUTE_REQUIREMENTS_SCHEMA,
424409
attribute_requirements,
425-
config_path=["saml2_config", "attribute_requirements"],
410+
config_path=("saml2_config", "attribute_requirements"),
426411
)
427-
return [SamlAttributeRequirement(**x) for x in attribute_requirements]
412+
return [SsoAttributeRequirement(**x) for x in attribute_requirements]

synapse/config/sso.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,28 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
from typing import Any, Dict
15+
from typing import Any, Dict, Optional
16+
17+
import attr
1618

1719
from ._base import Config
1820

1921

22+
@attr.s(frozen=True)
23+
class SsoAttributeRequirement:
24+
"""Object describing a single requirement for SSO attributes."""
25+
26+
attribute = attr.ib(type=str)
27+
# If a value is not given, than the attribute must simply exist.
28+
value = attr.ib(type=Optional[str])
29+
30+
JSON_SCHEMA = {
31+
"type": "object",
32+
"properties": {"attribute": {"type": "string"}, "value": {"type": "string"}},
33+
"required": ["attribute", "value"],
34+
}
35+
36+
2037
class SSOConfig(Config):
2138
"""SSO Configuration
2239
"""

synapse/handlers/cas_handler.py

+11-29
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
import logging
1616
import urllib.parse
17-
from typing import TYPE_CHECKING, Dict, Optional
17+
from typing import TYPE_CHECKING, Dict, List, Optional
1818
from xml.etree import ElementTree as ET
1919

2020
import attr
@@ -49,7 +49,7 @@ def __str__(self):
4949
@attr.s(slots=True, frozen=True)
5050
class CasResponse:
5151
username = attr.ib(type=str)
52-
attributes = attr.ib(type=Dict[str, Optional[str]])
52+
attributes = attr.ib(type=Dict[str, List[Optional[str]]])
5353

5454

5555
class CasHandler:
@@ -169,7 +169,7 @@ def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:
169169

170170
# Iterate through the nodes and pull out the user and any extra attributes.
171171
user = None
172-
attributes = {}
172+
attributes = {} # type: Dict[str, List[Optional[str]]]
173173
for child in root[0]:
174174
if child.tag.endswith("user"):
175175
user = child.text
@@ -182,7 +182,7 @@ def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:
182182
tag = attribute.tag
183183
if "}" in tag:
184184
tag = tag.split("}")[1]
185-
attributes[tag] = attribute.text
185+
attributes.setdefault(tag, []).append(attribute.text)
186186

187187
# Ensure a user was found.
188188
if user is None:
@@ -303,29 +303,10 @@ async def _handle_cas_response(
303303

304304
# Ensure that the attributes of the logged in user meet the required
305305
# attributes.
306-
for required_attribute, required_value in self._cas_required_attributes.items():
307-
# If required attribute was not in CAS Response - Forbidden
308-
if required_attribute not in cas_response.attributes:
309-
self._sso_handler.render_error(
310-
request,
311-
"unauthorised",
312-
"You are not authorised to log in here.",
313-
401,
314-
)
315-
return
316-
317-
# Also need to check value
318-
if required_value is not None:
319-
actual_value = cas_response.attributes[required_attribute]
320-
# If required attribute value does not match expected - Forbidden
321-
if required_value != actual_value:
322-
self._sso_handler.render_error(
323-
request,
324-
"unauthorised",
325-
"You are not authorised to log in here.",
326-
401,
327-
)
328-
return
306+
if not self._sso_handler.check_required_attributes(
307+
request, cas_response.attributes, self._cas_required_attributes
308+
):
309+
return
329310

330311
# Call the mapper to register/login the user
331312

@@ -372,9 +353,10 @@ async def cas_response_to_user_attributes(failures: int) -> UserAttributes:
372353
if failures:
373354
raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
374355

356+
# Arbitrarily use the first attribute found.
375357
display_name = cas_response.attributes.get(
376-
self._cas_displayname_attribute, None
377-
)
358+
self._cas_displayname_attribute, [None]
359+
)[0]
378360

379361
return UserAttributes(localpart=localpart, display_name=display_name)
380362

synapse/handlers/saml_handler.py

+4-22
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323

2424
from synapse.api.errors import SynapseError
2525
from synapse.config import ConfigError
26-
from synapse.config.saml2_config import SamlAttributeRequirement
2726
from synapse.handlers._base import BaseHandler
2827
from synapse.handlers.sso import MappingException, UserAttributes
2928
from synapse.http.servlet import parse_string
@@ -239,12 +238,10 @@ async def _handle_authn_response(
239238

240239
# Ensure that the attributes of the logged in user meet the required
241240
# attributes.
242-
for requirement in self._saml2_attribute_requirements:
243-
if not _check_attribute_requirement(saml2_auth.ava, requirement):
244-
self._sso_handler.render_error(
245-
request, "unauthorised", "You are not authorised to log in here."
246-
)
247-
return
241+
if not self._sso_handler.check_required_attributes(
242+
request, saml2_auth.ava, self._saml2_attribute_requirements
243+
):
244+
return
248245

249246
# Call the mapper to register/login the user
250247
try:
@@ -373,21 +370,6 @@ def expire_sessions(self):
373370
del self._outstanding_requests_dict[reqid]
374371

375372

376-
def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool:
377-
values = ava.get(req.attribute, [])
378-
for v in values:
379-
if v == req.value:
380-
return True
381-
382-
logger.info(
383-
"SAML2 attribute %s did not match required value '%s' (was '%s')",
384-
req.attribute,
385-
req.value,
386-
values,
387-
)
388-
return False
389-
390-
391373
DOT_REPLACE_PATTERN = re.compile(
392374
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
393375
)

synapse/handlers/sso.py

+71
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616
import logging
1717
from typing import (
1818
TYPE_CHECKING,
19+
Any,
1920
Awaitable,
2021
Callable,
2122
Dict,
2223
Iterable,
24+
List,
2325
Mapping,
2426
Optional,
2527
Set,
@@ -34,6 +36,7 @@
3436

3537
from synapse.api.constants import LoginType
3638
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
39+
from synapse.config.sso import SsoAttributeRequirement
3740
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
3841
from synapse.http import get_request_user_agent
3942
from synapse.http.server import respond_with_html, respond_with_redirect
@@ -893,6 +896,41 @@ def _expire_old_sessions(self):
893896
logger.info("Expiring mapping session %s", session_id)
894897
del self._username_mapping_sessions[session_id]
895898

899+
def check_required_attributes(
900+
self,
901+
request: SynapseRequest,
902+
attributes: Mapping[str, List[Any]],
903+
attribute_requirements: Iterable[SsoAttributeRequirement],
904+
) -> bool:
905+
"""
906+
Confirm that the required attributes were present in the SSO response.
907+
908+
If all requirements are met, this will return True.
909+
910+
If any requirement is not met, then the request will be finalized by
911+
showing an error page to the user and False will be returned.
912+
913+
Args:
914+
request: The request to (potentially) respond to.
915+
attributes: The attributes from the SSO IdP.
916+
attribute_requirements: The requirements that attributes must meet.
917+
918+
Returns:
919+
True if all requirements are met, False if any attribute fails to
920+
meet the requirement.
921+
922+
"""
923+
# Ensure that the attributes of the logged in user meet the required
924+
# attributes.
925+
for requirement in attribute_requirements:
926+
if not _check_attribute_requirement(attributes, requirement):
927+
self.render_error(
928+
request, "unauthorised", "You are not authorised to log in here."
929+
)
930+
return False
931+
932+
return True
933+
896934

897935
def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
898936
"""Extract the session ID from the cookie
@@ -903,3 +941,36 @@ def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
903941
if not session_id:
904942
raise SynapseError(code=400, msg="missing session_id")
905943
return session_id.decode("ascii", errors="replace")
944+
945+
946+
def _check_attribute_requirement(
947+
attributes: Mapping[str, List[Any]], req: SsoAttributeRequirement
948+
) -> bool:
949+
"""Check if SSO attributes meet the proper requirements.
950+
951+
Args:
952+
attributes: A mapping of attributes to an iterable of one or more values.
953+
requirement: The configured requirement to check.
954+
955+
Returns:
956+
True if the required attribute was found and had a proper value.
957+
"""
958+
if req.attribute not in attributes:
959+
logger.info("SSO attribute missing: %s", req.attribute)
960+
return False
961+
962+
# If the requirement is None, the attribute existing is enough.
963+
if req.value is None:
964+
return True
965+
966+
values = attributes[req.attribute]
967+
if req.value in values:
968+
return True
969+
970+
logger.info(
971+
"SSO attribute %s did not match required value '%s' (was '%s')",
972+
req.attribute,
973+
req.value,
974+
values,
975+
)
976+
return False

0 commit comments

Comments
 (0)