Skip to content

Commit

Permalink
Add missing Vector3 methods and tests (NanoComp#175)
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristopherHogan authored and stevengj committed Jan 19, 2018
1 parent 4f2c559 commit 522fb31
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 3 deletions.
31 changes: 28 additions & 3 deletions python/geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ def check_nonnegative(prop, val):
class Vector3(object):

def __init__(self, x=0.0, y=0.0, z=0.0):
self.x = float(x)
self.y = float(y)
self.z = float(z)
self.x = float(x) if type(x) is int else x
self.y = float(y) if type(y) is int else y
self.z = float(z) if type(z) is int else z

def __eq__(self, other):
return self.x == other.x and self.y == other.y and self.z == other.z
Expand Down Expand Up @@ -64,12 +64,37 @@ def scale(self, s):
def dot(self, v):
return self.x * v.x + self.y * v.y + self.z * v.z

def cdot(self, v):
conj_vec = Vector3(self.x.conjugate(),
self.y.conjugate(),
self.z.conjugate())
return conj_vec.dot(v)

def cross(self, v):
x = self.y * v.z - self.z * v.y
y = self.z * v.x - self.x * v.z
z = self.x * v.y - self.y * v.x

return Vector3(x, y, z)

def norm(self):
return math.sqrt(abs(self.dot(self)))

def unit(self):
return self.scale(1 / self.norm())

def close(self, v, tol=1.0e-7):
return (abs(self.x - v.x) <= tol and
abs(self.y - v.y) <= tol and
abs(self.z - v.z) <= tol)

def rotate(self, axis, theta):
u = axis.unit()
vpar = u.scale(u.dot(self))
vcross = u.cross(self)
vperp = self - vpar
return vpar + (vperp.scale(math.cos(theta)) + vcross.scale(math.sin(theta)))


class Medium(object):

Expand Down
35 changes: 35 additions & 0 deletions python/tests/geom.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from math import pi
import numpy as np
import meep as mp
import meep.geom as gm
Expand Down Expand Up @@ -253,5 +254,39 @@ def test_use_as_numpy_array(self):
self.assertTrue(type(res) is np.ndarray)
np.testing.assert_array_equal(np.array([20, 20, 20]), res)

def test_cross(self):
v1 = mp.Vector3(x=1)
v2 = mp.Vector3(z=1)
self.assertEqual(v1.cross(v2), mp.Vector3(y=-1))

v1 = mp.Vector3(1, 1)
v2 = mp.Vector3(0, 1, 1)
self.assertEqual(v1.cross(v2), mp.Vector3(1, -1, 1))

def test_cdot(self):
complex_vec1 = mp.Vector3(complex(1, 1), complex(1, 1), complex(1, 1))
complex_vec2 = mp.Vector3(complex(2, 2), complex(2, 2), complex(2, 2))

self.assertEqual(complex_vec1.cdot(complex_vec2), 12 + 0j)

def test_rotate(self):
axis = mp.Vector3(z=1)
v = mp.Vector3(x=1)
res = v.rotate(axis, pi)
self.assertTrue(res.close(mp.Vector3(x=-1)))

def test_close(self):
v1 = mp.Vector3(1e-7)
v2 = mp.Vector3(1e-8)
self.assertTrue(v1.close(v2))

v1 = mp.Vector3(1e-6)
v2 = mp.Vector3(1e-7)
self.assertFalse(v1.close(v2))

v1 = mp.Vector3(1e-10)
v2 = mp.Vector3(1e-11)
self.assertTrue(v1.close(v2, tol=1e-10))

if __name__ == '__main__':
unittest.main()

0 comments on commit 522fb31

Please sign in to comment.