Commit 659d651e authored by Matt Clarkson's avatar Matt Clarkson

Merge branch '2-merge-transforms-pkg'

parents a928a18f 8a8f5288
Pipeline #543 passed with stages
in 9 minutes and 19 seconds
# -*- coding: utf-8 -*-
"""Class implementing a general purpose 4x4 transformation matrix manager."""
import re
import numpy as np
class TransformManager:
"""
Class for managing 4x4 transformation matrices.
This class is NOT designed to be thread-safe.
The transforms are required to be 4x4 matrices.
There is no checking that the upper left 3x3 is
an orthonormal rotation matrix.
Usage::
tm = TransformManager()
# Imagine some example transformations:
t1 = np.eye(4)
t2 = np.eye(4)
t3 = np.eye(4)
# Add transformations to the TransformManager.
tm.add("model2world", t1)
tm.add("hand2eye",t2)
tm.add("hand2world",t3)
# Returns a transform from model to eye,
# by working through the above transforms.
t4 = tm.get("model2eye")
and so on.
"""
def __init__(self):
"""
Initialises an empty repository,
which will be a dictionary of dictionaries.
"""
self.repository = {}
@staticmethod
def is_valid_transform(transform):
"""
Validates the transform as a 4x4 numpy matrix.
:param transform: 4x4 transformation matrix.
:raises: TypeError, ValueError
"""
if not isinstance(transform, np.ndarray):
raise TypeError("transform is not a 2D numpy array")
if transform.shape[0] != 4:
raise ValueError("transform does not have 4 rows")
if transform.shape[1] != 4:
raise ValueError("transform does not have 4 columns")
@staticmethod
def is_valid_name(name):
"""
Validates the name, which must match "^([a-z]+)2([a-z]+)$".
i.e. one or more lowercase letters, followed by the number
2, followed by one or more lowercase letters.
For example::
a2b
model2world
Identity transforms such as model2model raise ValueError.
:param name: the name of the transform, eg. model2world
:raises: TypeError, ValueError
:returns: str, str -- parts of string before and after the 2.
"""
if not isinstance(name, str):
raise TypeError("name is not a string")
if not re.match("^([a-z]+)2([a-z]+)$", name):
raise ValueError("name is incorrectly formatted")
pre, post = name.split("2")
if pre == post:
raise ValueError("you shouldn't request the identity:"
+ pre + "2" + post)
return pre, post
@staticmethod
def flip_name(name):
"""
Returns the inverse name.
:param name: the name of a transformation, e.g. model2world
:returns: str -- the opposite transformation name, e.g. world2model
"""
before, after = TransformManager.is_valid_name(name)
return after + "2" + before
def exists(self, name):
"""
Returns True if the transform exists in the manager,
and False otherwise. Internally this class stores
the inverse. So, if you add model2world, you are
also implicitly adding world2model, so this
method will return True for both the originally
added transform, and its own inverse.
"""
before, after = self.is_valid_name(name)
return after in self.repository.keys() \
and before in self.repository[after].keys()
def count(self):
"""
Returns how many transforms are in the manager.
Internally this class also stores the inverse,
so this method will count those matrices as well.
"""
count = 0
for i in self.repository:
count += len(self.repository[i])
return count
def add(self, name, transform):
"""
Adds a transform called name.
If the name already exists, the corresponding
transform is replaced without warning.
:param name: the name of the transform, e.g. model2world
:param transform: the transform, e.g. 4x4 matrix
"""
before, after = self.is_valid_name(name)
self.is_valid_transform(transform)
if after not in self.repository.keys():
self.repository[after] = {}
if before not in self.repository.keys():
self.repository[before] = {}
self.repository[before][after] = transform
self.repository[after][before] = np.linalg.inv(transform)
def remove(self, name):
"""
Removes a transform from the manager.
If the transform name doesn't exist, will throw ValueError.
:raises: ValueError
"""
before, after = self.is_valid_name(name)
flipped = TransformManager.flip_name(name)
if not self.exists(name):
raise ValueError("name:" + name + ", is not in repository.")
if not self.exists(flipped):
raise ValueError("name:" + flipped + ", is not in repository.")
self.repository[before].pop(after)
self.repository[after].pop(before)
def multiply_point(self, name, points):
"""
Multiplies points (4xN) by the named transform (4x4).
:returns: ndarray -- 4xN matrix of transformed points
:raises: ValueError
"""
if not self.exists(name):
raise ValueError("name:" + name + ", could not be found.")
transform = self.get(name)
return np.matmul(transform, points)
def get(self, name):
"""
Returns the named transform or throws ValueError.
:raises: ValueError
"""
before, after = self.is_valid_name(name)
if before not in self.repository.keys() \
or after not in self.repository.keys():
raise ValueError("name:" + name + ", could not be found.")
result = self.__get_direct(name)
if result is not None:
return result
# If we didn't find it first time,
# search for a list of nodes from after to before.
list_of_nodes = [before]
self.__get_list(before, after, list_of_nodes)
# Multiply the nodes together. __get_list returns them
# in order (from before to after),
# so in the example model2world, model=before
# world=after, so the ordering returned from __get_list
# is from model to world. This is so we can simply
# pre-multiply them in the same order you normally
# do matrix multiplication.
result = np.eye(4)
for node_index in range(0, len(list_of_nodes) - 1):
next_name = list_of_nodes[node_index] \
+ "2" + list_of_nodes[node_index+1]
transform = self.get(next_name)
result = np.matmul(transform, result)
return result
def __get_direct(self, name):
"""
Internal method to return the named transform or None.
"""
before, after = self.is_valid_name(name)
if self.exists(name):
return self.repository[before][after]
return None
def __get_list(self, before, after, list_of_nodes):
"""
Internal method to work out a list of transforms
equivalent to the transform referred to by name.
"""
candidates = self.repository[before]
if after in candidates:
list_of_nodes.append(after)
return
for candidate in candidates:
if candidate in list_of_nodes:
continue
else:
list_of_nodes.append(candidate)
self.__get_list(candidate, after, list_of_nodes)
if list_of_nodes[-1] == after:
break
else:
list_of_nodes.pop()
# -*- coding: utf-8 -*-
import numpy as np
import pytest
import sksurgerycore.transforms.transform_manager as m
test_manager_matrix_tolerance = 0.0001
def test_invalid_empty_transform():
with pytest.raises(TypeError):
m.TransformManager.is_valid_transform(None)
def test_invalid_wrong_rows():
with pytest.raises(ValueError):
m.TransformManager.is_valid_transform(np.ones((2, 4)))
def test_invalid_wrong_cols():
with pytest.raises(ValueError):
m.TransformManager.is_valid_transform(np.ones((4, 2)))
def test_invalid_name_none():
with pytest.raises(TypeError):
m.TransformManager.is_valid_name(None)
def test_invalid_name_integer():
with pytest.raises(TypeError):
m.TransformManager.is_valid_name(1)
def test_invalid_name_empty():
# Empty string, is a string, so its
# a ValueError rather than a TypeError
with pytest.raises(ValueError):
m.TransformManager.is_valid_name("")
def test_invalid_name_no_2():
with pytest.raises(ValueError):
m.TransformManager.is_valid_name("appletobanana")
def test_invalid_name_wrong_number():
with pytest.raises(ValueError):
m.TransformManager.is_valid_name("apple3banana")
def test_invalid_name_upper_case():
with pytest.raises(ValueError):
m.TransformManager.is_valid_name("Apple2Banna")
def test_invalid_name_too_many_parts():
with pytest.raises(ValueError):
m.TransformManager.is_valid_name("apple2banana2pear")
def test_invalid_name_implicit_identity():
with pytest.raises(ValueError):
m.TransformManager.is_valid_name("world2world")
def test_add_invalid_name():
with pytest.raises(ValueError):
tm = m.TransformManager()
tm.add("banana", None)
def test_add_invalid_transform():
with pytest.raises(TypeError):
tm = m.TransformManager()
tm.add("model2world", None)
def test_add_invalid_size():
with pytest.raises(ValueError):
tm = m.TransformManager()
tm.add("model2world", np.eye(5))
def test_remove_invalid_name():
with pytest.raises(ValueError):
tm = m.TransformManager()
tm.remove("banana")
def test_remove_non_existing_name():
with pytest.raises(ValueError):
tm = m.TransformManager()
tm.remove("model2world")
def test_add_remove_valid():
tm = m.TransformManager()
tm.add("model2world", np.eye(4))
tm.add("camera2world", np.eye(4))
assert tm.exists("model2world")
assert tm.exists("camera2world")
assert tm.count() == 4 # as we also store inverse
tm.remove("model2world")
assert tm.count() == 2 # as we also store inverse
tm.repository.pop("camera") # Hack, as member vars aren't private.
assert tm.count() == 1
with pytest.raises(ValueError):
tm.remove("camera2world")
def test_multiply_point():
t = np.eye(4)
t[0][3] = 1
t[1][3] = 2
t[2][3] = 3
model_point = np.ones((4, 1))
expected_world_point = np.ones((4, 1))
expected_world_point[0][0] = 2
expected_world_point[1][0] = 3
expected_world_point[2][0] = 4
expected_world_point[3][0] = 1
tm = m.TransformManager()
tm.add("model2world", t)
with pytest.raises(ValueError):
tm.multiply_point("a2b", model_point)
world_point = tm.multiply_point("model2world", model_point)
assert np.allclose(expected_world_point, world_point, test_manager_matrix_tolerance)
def test_get_exact_match():
t = np.eye(4)
tm = m.TransformManager()
tm.add("model2world", t)
r = tm.get("model2world")
assert np.allclose(t, r, test_manager_matrix_tolerance)
with pytest.raises(ValueError):
r = tm.get("hand2eye")
def test_get_inverse_match():
t = np.eye(4)
t[0][3] = 1
t[1][3] = 2
t[2][3] = 3
tm = m.TransformManager()
tm.add("model2world", t)
r = tm.get("world2model")
assert np.allclose(t, np.linalg.inv(r), test_manager_matrix_tolerance)
def test_get_2_step_path():
a2b = np.eye(4)
a2b[0][3] = 1
a2b[1][3] = 2
a2b[2][3] = 3
b2c = np.eye(4)
b2c[0][3] = 1
b2c[1][3] = 2
b2c[2][3] = 3
tm = m.TransformManager()
tm.add("a2b", a2b)
tm.add("b2c", b2c)
r = tm.get("a2c")
assert np.allclose(np.matmul(b2c, a2b), r, test_manager_matrix_tolerance)
def test_get_2_step_path_with_inverse():
a2b = np.eye(4)
a2b[0][3] = 1
a2b[1][3] = 2
a2b[2][3] = 3
c2b = np.eye(4)
c2b[0][3] = 1
c2b[1][3] = 2
c2b[2][3] = 3
tm = m.TransformManager()
tm.add("a2b", a2b)
tm.add("c2b", c2b)
r = tm.get("a2c")
assert np.allclose(np.matmul(np.linalg.inv(c2b), a2b), r, test_manager_matrix_tolerance)
def create_test_matrix(seed):
mat = np.eye(4)
mat[0][3] = seed
mat[1][3] = seed + 1
mat[2][3] = seed + 2
return mat
def test_get_path_with_diversions():
a2b = create_test_matrix(1)
b2c = create_test_matrix(2)
c2d = create_test_matrix(3)
b2e = create_test_matrix(4)
e2f = create_test_matrix(5)
b2g = create_test_matrix(6)
tm = m.TransformManager()
tm.add("a2b", a2b)
tm.add("b2c", b2c)
tm.add("c2d", c2d)
tm.add("b2e", b2e)
tm.add("e2f", e2f)
tm.add("b2g", b2g)
r = tm.get("a2f")
assert np.allclose(
np.matmul(e2f, np.matmul(b2e, a2b)), r, test_manager_matrix_tolerance)
def test_get_path_with_multiple_path():
a2b = create_test_matrix(1)
b2c = create_test_matrix(2)
c2d = create_test_matrix(3)
a2e = create_test_matrix(3)
e2f = create_test_matrix(2)
f2d = create_test_matrix(1)
tm = m.TransformManager()
tm.add("a2b", a2b)
tm.add("b2c", b2c)
tm.add("c2d", c2d)
tm.add("a2e", a2e)
tm.add("e2f", e2f)
tm.add("f2d", f2d)
r = tm.get("a2d")
assert np.allclose(
np.matmul(c2d, np.matmul(b2c, a2b)), r, test_manager_matrix_tolerance)
def test_get_path_with_y_shape():
a2b = create_test_matrix(1)
b2c = create_test_matrix(2)
d2e = create_test_matrix(3)
e2c = create_test_matrix(4)
c2f = create_test_matrix(5)
f2g = create_test_matrix(6)
tm = m.TransformManager()
tm.add("a2b", a2b)
tm.add("b2c", b2c)
tm.add("d2e", d2e)
tm.add("e2c", e2c)
tm.add("c2f", c2f)
tm.add("f2g", f2g)
r = tm.get("d2g")
assert np.allclose(
np.matmul(np.matmul(c2f, np.matmul(e2c, d2e)), f2g), r, test_manager_matrix_tolerance)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment