Skip to content

Commit 42b9694

Browse files
committed
[SymForce] Add sf.clamp
Reviewers: nathan,bradley,harrison,chao,hayk Topic: sf-clamp Relative: sf-databuffer-unused GitOrigin-RevId: 16ff3f5a999951e2c1f137ef9998812f9d445a72
1 parent 718d061 commit 42b9694

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

symforce/internal/symbolic.py

+15
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,21 @@ def acos_safe(x: Scalar, epsilon: Scalar = epsilon()) -> Scalar:
441441
return sympy.acos(x_safe)
442442

443443

444+
def clamp(x: sf.Scalar, min_value: sf.Scalar, max_value: sf.Scalar) -> sf.Scalar:
445+
"""
446+
Returns min_value if x < min_value
447+
Returns x if min_value < x < max_value
448+
Returns max_value if x > max_value
449+
450+
Args:
451+
x: Value to clamp between min_value and max_value
452+
min_value: Scalar of same type and units as x; minimum value to return
453+
max_value: Scalar of same type and units as x; maximum value to return. Must be greater
454+
than min_value.
455+
"""
456+
return sf.Min(max_value, sf.Max(min_value, x))
457+
458+
444459
def set_eval_on_sympify(eval_on_sympy: bool = True) -> None:
445460
"""
446461
When using the symengine backed, set whether we should eval args when converting objects to

test/symforce_custom_methods_test.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212

1313
class SymforceCustomMethodsTest(TestCase):
1414
"""
15-
Test the custom methods added by add_custom_methods in initialization.py
15+
Test the custom methods added by in the "custom functions" section in symbolic.py
1616
"""
1717

1818
def test_arg_maxes(self) -> None:
1919
"""
2020
Tests:
21-
symforce.sympy.argmax_onehot
22-
symforce.sympy.argmax
21+
sf.argmax_onehot
22+
sf.argmax
2323
Check that the argmax functions return the correct output
2424
"""
2525

@@ -42,8 +42,8 @@ def test_arg_maxes(self) -> None:
4242
def test_arg_maxes_other_sequences(self) -> None:
4343
"""
4444
Tests:
45-
symforce.sympy.argmax_onehot
46-
symforce.sympy.argmax
45+
sf.argmax_onehot
46+
sf.argmax
4747
Check that the argmax functions work on non-list sequences
4848
"""
4949
vals_range = range(5)
@@ -58,6 +58,15 @@ def test_arg_maxes_other_sequences(self) -> None:
5858
self.assertEqual([1, 0, 0, 0], sf.argmax_onehot(vals_arr))
5959
self.assertEqual(0, sf.argmax(vals_arr))
6060

61+
def test_clamp(self) -> None:
62+
"""
63+
Tests:
64+
sf.clamp
65+
"""
66+
self.assertEqual(1, sf.clamp(-10, 1, 5))
67+
self.assertEqual(3, sf.clamp(3, 1, 5))
68+
self.assertEqual(5, sf.clamp(10, 1, 5))
69+
6170

6271
if __name__ == "__main__":
6372
TestCase.main()

0 commit comments

Comments
 (0)