Skip to content

Commit 7d4f5c5

Browse files
[SymForce] Replace sympy.count_ops w/ better impl
In this commit, I just copy the existing sympy implementation into symforce. The only changes made to it are: - Adding typing to the signature - Reformatting - Adding # type: ignore to various places - Adding # pylint: disable=... to various places I opted to do it this way to be minimally invasive to make it easier to identify which changes we've made vs what was already present. Topic: improve_sympy_count_ops GitOrigin-RevId: a37529f77803e0991aadb31d4afba94029547396
1 parent 4b8b9f3 commit 7d4f5c5

File tree

8 files changed

+296
-5
lines changed

8 files changed

+296
-5
lines changed

symforce/_sympy_count_ops.py

+263
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
# ----------------------------------------------------------------------------
2+
# SymForce - Copyright 2022, Skydio, Inc.
3+
# This source code is under the Apache 2.0 license found in the LICENSE file.
4+
# ----------------------------------------------------------------------------
5+
6+
"""
7+
This file defines a modified version of sympy.count_ops which is more suited
8+
to our needs.
9+
10+
Based on: https://github.com/sympy/sympy/blob/c1deac1ce6c989ee2c66fc505595603cca77681f/sympy/core/function.py#L3059
11+
with light modification.
12+
13+
Note, there are a few instances of of # type: ignore, and # pylint: disable=...
14+
While the code could be modified to pass linting without these, we wanted to
15+
change as little as possible from the original sympy code (to better distinguish
16+
our own material changes).
17+
18+
Currently, the body of of count_ops is the body as found in sympy, modified only
19+
with those comments, formatting, and material changes to the implementation.
20+
"""
21+
22+
import sympy
23+
from sympy import sympify
24+
from sympy import Expr
25+
from sympy import Symbol
26+
from sympy import S
27+
from sympy import Basic
28+
from sympy import Derivative
29+
from sympy import Add
30+
from sympy.core.compatibility import iterable
31+
from sympy.core.function import UndefinedFunction
32+
from sympy.core.operations import LatticeOp
33+
34+
from symforce import typing as T
35+
from symforce.typing import Dict
36+
37+
38+
def _coeff_isneg(a: sympy.Expr) -> bool:
39+
"""Return True if the leading Number is negative.
40+
Examples
41+
========
42+
>>> from sympy.core.function import _coeff_isneg
43+
>>> from sympy import S, Symbol, oo, pi
44+
>>> _coeff_isneg(-3*pi)
45+
True
46+
>>> _coeff_isneg(S(3))
47+
False
48+
>>> _coeff_isneg(-oo)
49+
True
50+
>>> _coeff_isneg(Symbol('n', negative=True)) # coeff is 1
51+
False
52+
For matrix expressions:
53+
>>> from sympy import MatrixSymbol, sqrt
54+
>>> A = MatrixSymbol("A", 3, 3)
55+
>>> _coeff_isneg(-sqrt(2)*A)
56+
True
57+
>>> _coeff_isneg(sqrt(2)*A)
58+
False
59+
"""
60+
61+
if a.is_MatMul:
62+
a = a.args[0]
63+
if a.is_Mul:
64+
a = a.args[0]
65+
return a.is_Number and a.is_extended_negative
66+
67+
68+
def count_ops(expr: sympy.Expr, visual: bool = False) -> T.Union[sympy.Expr, int]:
69+
"""
70+
Return a representation (integer or expression) of the operations in expr.
71+
Parameters
72+
==========
73+
expr : Expr
74+
If expr is an iterable, the sum of the op counts of the
75+
items will be returned.
76+
visual : bool, optional
77+
If ``False`` (default) then the sum of the coefficients of the
78+
visual expression will be returned.
79+
If ``True`` then the number of each type of operation is shown
80+
with the core class types (or their virtual equivalent) multiplied by the
81+
number of times they occur.
82+
Examples
83+
========
84+
>>> from sympy.abc import a, b, x, y
85+
>>> from sympy import sin, count_ops
86+
Although there isn't a SUB object, minus signs are interpreted as
87+
either negations or subtractions:
88+
>>> (x - y).count_ops(visual=True)
89+
SUB
90+
>>> (-x).count_ops(visual=True)
91+
NEG
92+
Here, there are two Adds and a Pow:
93+
>>> (1 + a + b**2).count_ops(visual=True)
94+
2*ADD + POW
95+
In the following, an Add, Mul, Pow and two functions:
96+
>>> (sin(x)*x + sin(x)**2).count_ops(visual=True)
97+
ADD + MUL + POW + 2*SIN
98+
for a total of 5:
99+
>>> (sin(x)*x + sin(x)**2).count_ops(visual=False)
100+
5
101+
Note that "what you type" is not always what you get. The expression
102+
1/x/y is translated by sympy into 1/(x*y) so it gives a DIV and MUL rather
103+
than two DIVs:
104+
>>> (1/x/y).count_ops(visual=True)
105+
DIV + MUL
106+
The visual option can be used to demonstrate the difference in
107+
operations for expressions in different forms. Here, the Horner
108+
representation is compared with the expanded form of a polynomial:
109+
>>> eq=x*(1 + x*(2 + x*(3 + x)))
110+
>>> count_ops(eq.expand(), visual=True) - count_ops(eq, visual=True)
111+
-MUL + 3*POW
112+
The count_ops function also handles iterables:
113+
>>> count_ops([x, sin(x), None, True, x + 2], visual=False)
114+
2
115+
>>> count_ops([x, sin(x), None, True, x + 2], visual=True)
116+
ADD + SIN
117+
>>> count_ops({x: sin(x), x + 2: y + 1}, visual=True)
118+
2*ADD + SIN
119+
"""
120+
from sympy import Integral, Sum
121+
from sympy.core.relational import Relational
122+
from sympy.simplify.radsimp import fraction
123+
from sympy.logic.boolalg import BooleanFunction
124+
from sympy.utilities.misc import func_name
125+
126+
expr = sympify(expr)
127+
128+
# pylint: disable=too-many-nested-blocks
129+
if isinstance(expr, Expr) and not expr.is_Relational:
130+
131+
ops = []
132+
args = [expr]
133+
NEG = Symbol("NEG")
134+
DIV = Symbol("DIV")
135+
SUB = Symbol("SUB")
136+
ADD = Symbol("ADD")
137+
EXP = Symbol("EXP")
138+
while args:
139+
a = args.pop()
140+
141+
# if the following fails because the object is
142+
# not Basic type, then the object should be fixed
143+
# since it is the intention that all args of Basic
144+
# should themselves be Basic
145+
if a.is_Mul or a.is_MatMul:
146+
if _coeff_isneg(a):
147+
ops.append(NEG)
148+
if a.args[0] is S.NegativeOne:
149+
a = a.as_two_terms()[1]
150+
else:
151+
a = -a
152+
n, d = fraction(a)
153+
if n.is_Integer:
154+
ops.append(DIV)
155+
if n < 0:
156+
ops.append(NEG)
157+
args.append(d)
158+
continue # won't be -Mul but could be Add
159+
elif d is not S.One:
160+
if not d.is_Integer:
161+
args.append(d)
162+
ops.append(DIV)
163+
args.append(n)
164+
continue # could be -Mul
165+
elif a.is_Add or a.is_MatAdd:
166+
aargs = list(a.args)
167+
negs = 0
168+
for i, ai in enumerate(aargs):
169+
if _coeff_isneg(ai):
170+
negs += 1
171+
args.append(-ai)
172+
if i > 0:
173+
ops.append(SUB)
174+
else:
175+
args.append(ai)
176+
if i > 0:
177+
ops.append(ADD)
178+
if negs == len(aargs): # -x - y = NEG + SUB
179+
ops.append(NEG)
180+
elif _coeff_isneg(aargs[0]): # -x + y = SUB, but already recorded ADD
181+
ops.append(SUB - ADD)
182+
continue
183+
if a.is_Pow and a.exp is S.NegativeOne:
184+
ops.append(DIV)
185+
args.append(a.base) # won't be -Mul but could be Add
186+
continue
187+
if a == S.Exp1:
188+
ops.append(EXP)
189+
continue
190+
if a.is_Pow and a.base == S.Exp1:
191+
ops.append(EXP)
192+
args.append(a.exp)
193+
continue
194+
if a.is_Mul or isinstance(a, LatticeOp):
195+
o = Symbol(a.func.__name__.upper())
196+
# count the args
197+
ops.append(o * (len(a.args) - 1))
198+
elif a.args and ( # pylint: disable=too-many-boolean-expressions
199+
a.is_Pow # pylint: disable=consider-merging-isinstance
200+
or a.is_Function
201+
or isinstance(a, Derivative)
202+
or isinstance(a, Integral)
203+
or isinstance(a, Sum)
204+
):
205+
# if it's not in the list above we don't
206+
# consider a.func something to count, e.g.
207+
# Tuple, MatrixSymbol, etc...
208+
if isinstance(a.func, UndefinedFunction):
209+
o = Symbol("FUNC_" + a.func.__name__.upper())
210+
else:
211+
o = Symbol(a.func.__name__.upper())
212+
ops.append(o)
213+
214+
if not a.is_Symbol:
215+
args.extend(a.args)
216+
217+
elif isinstance(expr, Dict):
218+
ops = [count_ops(k, visual=visual) + count_ops(v, visual=visual) for k, v in expr.items()]
219+
elif iterable(expr):
220+
ops = [count_ops(i, visual=visual) for i in expr]
221+
elif isinstance(expr, (Relational, BooleanFunction)):
222+
ops = []
223+
for arg in expr.args:
224+
ops.append(count_ops(arg, visual=True))
225+
o = Symbol(func_name(expr, short=True).upper())
226+
ops.append(o)
227+
elif not isinstance(expr, Basic):
228+
ops = []
229+
else: # it's Basic not isinstance(expr, Expr):
230+
if not isinstance(expr, Basic):
231+
raise TypeError("Invalid type of expr")
232+
else:
233+
ops = []
234+
args = [expr]
235+
while args:
236+
a = args.pop()
237+
238+
if a.args:
239+
o = Symbol(type(a).__name__.upper())
240+
if a.is_Boolean:
241+
ops.append(o * (len(a.args) - 1))
242+
else:
243+
ops.append(o)
244+
args.extend(a.args)
245+
246+
if not ops:
247+
if visual:
248+
return S.Zero
249+
return 0
250+
251+
ops = Add(*ops)
252+
253+
if visual:
254+
return ops
255+
256+
# NOTE(brad): This type ignore is needed becasue above we change the type of ops
257+
# from a list to a sympy.Add. This could be fixed by changing the names of the
258+
# variables, but this is how it was done in sympy's implementation and I'd rather
259+
# not make changes that don't change functionality (to reduce noise).
260+
if ops.is_Number: # type: ignore
261+
return int(ops) # type: ignore
262+
263+
return sum(int((a.args or [1])[0]) for a in Add.make_args(ops))

symforce/initialization.py

+12
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from symforce import logger
1414
from symforce import typing as T
1515

16+
from . import _sympy_count_ops
17+
1618

1719
def modify_symbolic_api(sympy_module: T.Any) -> None:
1820
"""
@@ -28,6 +30,7 @@ def modify_symbolic_api(sympy_module: T.Any) -> None:
2830
add_scoping(sympy_module)
2931
add_custom_methods(sympy_module)
3032
override_solve(sympy_module)
33+
override_count_ops(sympy_module)
3134

3235

3336
def override_symbol_new(sympy_module: T.Any) -> None:
@@ -261,6 +264,15 @@ def solve(*args: T.Any, **kwargs: T.Any) -> T.List[T.Scalar]:
261264
raise NotImplementedError(f"Unknown backend: '{sympy_module.__name__}'")
262265

263266

267+
def override_count_ops(sympy_module: T.Type) -> None:
268+
"""
269+
Patch count_ops to yield more reasonable outputs from the perspective of generated code. Only
270+
sympy.count_ops is modified here as the symengine.count_ops is modified directly.
271+
"""
272+
if sympy_module.__name__ == "sympy":
273+
sympy_module.count_ops = _sympy_count_ops.count_ops
274+
275+
264276
def add_custom_methods(sympy_module: T.Type) -> None:
265277
"""
266278
Add safe helper methods to the symbolic API.

test/count_ops_test.py

+16
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,22 @@ def test_division(self) -> None:
5656
self.assertEqual(2, sm.count_ops(x * y ** 2))
5757
self.assertEqual(2, sm.count_ops(x * y ** (-2)))
5858

59+
def test_constants(self) -> None:
60+
"""
61+
Tests that constants are no ops.
62+
"""
63+
with self.subTest(msg=f"{sm.__name__} counts decimals as 0 ops"):
64+
self.assertEqual(0, sm.count_ops(1.1))
65+
self.assertEqual(0, sm.count_ops(-1.1))
66+
67+
with self.subTest(msg=f"{sm.__name__} counts integers as 0 ops"):
68+
self.assertEqual(0, sm.count_ops(2))
69+
self.assertEqual(0, sm.count_ops(-sm.S(2)))
70+
71+
with self.subTest(msg=f"{sm.__name__} counts rationals as 0 ops"):
72+
self.assertEqual(0, sm.count_ops(sm.Rational(2, 3)))
73+
self.assertEqual(0, sm.count_ops(sm.Rational(-2, 3)))
74+
5975

6076
if __name__ == "__main__":
6177
TestCase.main()

test/symforce_function_codegen_test_data/sympy/az_el_from_point.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ template <typename Scalar>
2929
Eigen::Matrix<Scalar, 2, 1> AzElFromPoint(const sym::Pose3<Scalar>& nav_T_cam,
3030
const Eigen::Matrix<Scalar, 3, 1>& nav_t_point,
3131
const Scalar epsilon) {
32-
// Total ops: 78
32+
// Total ops: 77
3333

3434
// Input arrays
3535
const Eigen::Matrix<Scalar, 7, 1>& _nav_T_cam = nav_T_cam.Data();

test/symforce_function_codegen_test_data/sympy/az_el_from_point.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def az_el_from_point(nav_T_cam, nav_t_point, epsilon):
2929
geo.Matrix: (azimuth, elevation)
3030
3131
"""
32-
# Total ops: 78
32+
# Total ops: 77
3333

3434
# Input arrays
3535
_nav_T_cam = nav_T_cam.data

test/symforce_function_codegen_test_data/sympy/codegen_nan_test_data/cpp/symforce/codegen_nan_test/identity_dist_jacobian.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ namespace codegen_nan_test {
2525
*/
2626
template <typename Scalar>
2727
Scalar IdentityDistJacobian(const sym::Rot3<Scalar>& R1, const Scalar e) {
28-
// Total ops: 41
28+
// Total ops: 39
2929

3030
// Input arrays
3131
const Eigen::Matrix<Scalar, 4, 1>& _R1 = R1.Data();

test/symforce_function_codegen_test_data/sympy/gnc_test_data/cpp/symforce/gnc_factors/barron_factor.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ void BarronFactor(const Eigen::Matrix<Scalar, 5, 1>& x, const Eigen::Matrix<Scal
3434
Eigen::Matrix<Scalar, 5, 5>* const jacobian = nullptr,
3535
Eigen::Matrix<Scalar, 5, 5>* const hessian = nullptr,
3636
Eigen::Matrix<Scalar, 5, 1>* const rhs = nullptr) {
37-
// Total ops: 158
37+
// Total ops: 152
3838

3939
// Input arrays
4040

test/symforce_function_codegen_test_data/sympy/heaviside.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ namespace cpp_code_printer_test {
2424
*/
2525
template <typename Scalar>
2626
Scalar Heaviside(const Scalar x) {
27-
// Total ops: 2
27+
// Total ops: 1
2828

2929
// Input arrays
3030

0 commit comments

Comments
 (0)