"""Visual odometry experiment.

Estimates the relative camera pose between two consecutive frames and
compares it against the recorded vehicle poses from the parsed dataset.
"""
from pathlib import Path
from typing import Optional, Sequence
import cv2
import numpy as np
from matplotlib import pyplot as plt
from scipy.spatial.transform import Rotation, RigidTransform
from sqlalchemy import create_engine, desc
from sqlalchemy.orm import Session
from util import relative_transform
from data_parser import *
class VisualOdometry:
    """Monocular visual odometry from image pairs.

    Detects SIFT keypoints, matches them with a FLANN-based matcher, and
    recovers the relative camera motion from the essential matrix.
    """

    def __init__(self,
                 K: np.ndarray,
                 index_params: Optional[dict[str, int]] = None,
                 search_params: Optional[dict[str, int]] = None):
        """ Constructor

        Args:
            K (np.ndarray): Camera Intrinsics Model
            index_params (Optional[dict[str, int]], optional): Index parameters for FLANN. Defaults to {"algorithm": 1, "trees": 5}.
            search_params (Optional[dict[str, int]], optional): Search parameters for FLANN. Defaults to {"checks": 50}.
        """
        # None sentinels instead of mutable default arguments (shared dict
        # defaults are a classic Python pitfall); behavior is unchanged.
        if index_params is None:
            index_params = {"algorithm": 1, "trees": 5}
        if search_params is None:
            search_params = {"checks": 50}

        self.K = K

        self.sift = cv2.SIFT_create()  # pyright: ignore[reportAttributeAccessIssue]
        self.flann = cv2.FlannBasedMatcher(
            indexParams=index_params, searchParams=search_params)  # pyright: ignore[reportArgumentType]

    def extract_keypoints(self, img: cv2.typing.MatLike) -> tuple[list[cv2.KeyPoint], np.ndarray]:
        """ Detects SIFT keypoints in an image and computes their descriptors

        Args:
            img (cv2.typing.MatLike): input image

        Returns:
            kp (list[cv2.KeyPoint]): list of keypoints
            desc (np.ndarray): descriptor of the keypoints, one row per keypoint
        """
        return self.sift.detectAndCompute(img, None)

    def match_keypoints(self,
                        desc1: np.ndarray,
                        desc2: np.ndarray,
                        k: int = 2) -> Sequence[Sequence[cv2.DMatch]]:
        """ Finds the k nearest matches in desc2 for each descriptor in desc1

        Args:
            desc1 (np.ndarray): image 1 keypoint description
            desc2 (np.ndarray): image 2 keypoint description
            k (int, optional): number of nearest neighbours per query. Defaults to 2.

        Returns:
            Sequence[Sequence[cv2.DMatch]]: sequence of candidate matches per query descriptor
        """
        return self.flann.knnMatch(desc1, desc2, k=k)

    def filter_matches(self, matches: Sequence[Sequence[cv2.DMatch]], distance_threshold: float = 0.7) -> list[cv2.DMatch]:
        """ Filters out good keypoint matches using Lowe's ratio test

        A match is kept when its best neighbour is sufficiently closer than
        the second-best neighbour.

        Args:
            matches (Sequence[Sequence[cv2.DMatch]]): list of keypoint matches (k=2 expected)
            distance_threshold (float, optional): distance percent threshold for filtering. Defaults to 0.7.

        Returns:
            list[cv2.DMatch]: list of good matches
        """
        # Guard against candidate lists with fewer than two neighbours, which
        # knnMatch may return for small descriptor sets; the previous
        # unguarded `for m, n in matches` unpacking raised ValueError there.
        return [pair[0] for pair in matches
                if len(pair) == 2 and pair[0].distance < distance_threshold * pair[1].distance]

    def estimate_motion(self, kp1: list[cv2.KeyPoint], kp2: list[cv2.KeyPoint], matches: list[cv2.DMatch]) -> RigidTransform:
        """ Estimates the relative camera motion between two views

        Computes the essential matrix from the matched keypoints (RANSAC) and
        decomposes it into a rotation and a unit-scale translation.

        Args:
            kp1 (list[cv2.KeyPoint]): keypoints of the first image
            kp2 (list[cv2.KeyPoint]): keypoints of the second image
            matches (list[cv2.DMatch]): filtered matches between the two keypoint sets

        Returns:
            RigidTransform: relative transform; the translation is only
            defined up to scale for a monocular setup.
        """
        pts1 = np.float32([kp1[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)  # pyright: ignore[reportArgumentType]
        pts2 = np.float32([kp2[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)  # pyright: ignore[reportArgumentType]

        E, mask = cv2.findEssentialMat(
            points1=pts1,
            points2=pts2,
            cameraMatrix=self.K,
            method=cv2.RANSAC,
            prob=.999,
            threshold=1.0)

        # Forward the RANSAC inlier mask so pose recovery only considers the
        # inliers found while estimating E (previously the mask was discarded).
        _, R, t, _ = cv2.recoverPose(E, pts1, pts2, cameraMatrix=self.K, mask=mask)

        return RigidTransform.from_components(translation=t.transpose(), rotation=Rotation.from_matrix(R))

    def draw_keypoint_matches(self,
                              img1: cv2.typing.MatLike,
                              kp1: list[cv2.KeyPoint],
                              img2: cv2.typing.MatLike,
                              kp2: list[cv2.KeyPoint],
                              matches: list[cv2.DMatch],
                              output_image: Optional[cv2.typing.MatLike] = None) -> cv2.typing.MatLike:
        """ Generates an image drawing the keypoint matches between two images in the image

        Args:
            img1 (cv2.typing.MatLike): first image
            kp1 (list[cv2.KeyPoint]): first image keypoints
            img2 (cv2.typing.MatLike): second image
            kp2 (list[cv2.KeyPoint]): second image keypoints
            matches (list[cv2.DMatch]): list of matches accepted
            output_image (Optional[cv2.typing.MatLike], optional): output image buffer. If None or omitted, a new one will be created.

        Returns:
            cv2.typing.MatLike: the rendered match visualization
        """
        # Draw matches; unmatched keypoints are suppressed for readability.
        return cv2.drawMatches(img1, kp1, img2, kp2, matches, output_image, flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS)  # pyright: ignore[reportArgumentType, reportCallIssue]

    @staticmethod
    def show_keypoint_matches(match_image: cv2.typing.MatLike) -> None:
        """ Show image matches in a matplotlib window (blocks until closed)

        Args:
            match_image (cv2.typing.MatLike): match image
        """
        plt.figure(figsize=(15, 10))
        plt.imshow(match_image, cmap='gray')
        plt.title('Matched Keypoints Between Two Images')
        plt.axis('off')
        plt.show()
def main():
    """Runs the visual-odometry pipeline on two consecutive dataset frames
    and compares the estimated pose against the recorded vehicle poses.
    """
    # Create an in-memory database for the dataset
    # TODO Move this to dataset library
    engine = create_engine('sqlite:///:memory:')
    create_tables(engine)
    session = Session(bind=engine)

    # Import dataset (populates the session as a side effect; the returned
    # Dataset object itself is not needed here)
    data_root = Path('./train1/3d20ae25-5b29-320d-8bae-f03e9dc177b9')
    Dataset.import_dataset(data_root, session)

    # Get Camera
    camera = session.query(Camera).filter_by(name='ring_front_center').first()

    if camera is None:
        raise RuntimeError("Camera not found")

    # Load two consecutive frames from that camera
    view1 = camera.camera_views[0]
    view2 = camera.camera_views[1]
    img1_path = view1.get_path()
    img2_path = view2.get_path()
    img1 = cv2.imread(img1_path)
    img2 = cv2.imread(img2_path)

    if img1 is None:
        raise RuntimeError(f"Could not open or find the image {img1_path}")

    if img2 is None:
        raise RuntimeError(f"Could not open or find the image {img2_path}")

    # Create an instance of the VisualOdometry class
    vo = VisualOdometry(K=camera.get_intrinsics())

    # Extract Keypoints
    kp1, desc1 = vo.extract_keypoints(img1)
    kp2, desc2 = vo.extract_keypoints(img2)

    # Match Keypoints
    matches = vo.match_keypoints(desc1, desc2)

    # Filter Keypoints
    good_matches = vo.filter_matches(matches)

    # Draw matches
    img_matches = vo.draw_keypoint_matches(img1, kp1, img2, kp2, good_matches)

    # Show Matches
    VisualOdometry.show_keypoint_matches(img_matches)

    # Estimate pose
    T = vo.estimate_motion(kp1, kp2, good_matches)

    # Get Recorded Poses
    T_wa = view1.timestamp.get_vehicle_pose()
    T_wb = view2.timestamp.get_vehicle_pose()
    T_ba = relative_transform(T_wa, T_wb)

    # Print results
    t, R = T.as_components()
    t_ba, R_ba = T_ba.as_components()

    print(f"Calculated Rotation matrix: \n{R.as_matrix()}")
    print(f"Recorded Rotation matrix: \n{R_ba.as_matrix()}")
    print(f"Difference Rotation matrix: \n{R_ba.as_matrix() - R.as_matrix()}")
    print()

    print(f"Calculated Euler Angles: {R.as_euler('xyz')}")
    print(f"Recorded Euler Angles: {R_ba.as_euler('xyz')}")
    print(f"Difference Euler Angles: {R_ba.as_euler('xyz') - R.as_euler('xyz')}")
    print()

    # Fixed wording: these are quaternions ("Quatrains" are poem stanzas).
    # NOTE(review): q and -q encode the same rotation, so a component-wise
    # quaternion difference can look large for nearly identical rotations.
    print(f"Calculated Quaternions: {R.as_quat()}")
    print(f"Recorded Quaternions: {R_ba.as_quat()}")
    print(f"Difference Quaternions: {R_ba.as_quat() - R.as_quat()}")
    print()

    # NOTE(review): monocular VO recovers translation only up to scale, so
    # the raw difference against the metric recorded translation is not
    # directly meaningful — compare directions, not magnitudes.
    print(f"Calculated Translation: {t}")
    print(f"Recorded Translation: {t_ba}")
    print(f"Difference Translation: {t_ba - t}")
# Script entry point: only run the pipeline when executed directly.
if __name__ == "__main__":
    main()