from pathlib import Path
from typing import Optional, Sequence

import cv2
import numpy as np
from matplotlib import pyplot as plt
from scipy.spatial.transform import Rotation, RigidTransform
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

from util import relative_transform
from data_parser import *


class VisualOdometry:
    def __init__(self,
                 K: np.ndarray,
                 index_params: dict[str, int] = {"algorithm": 1, "trees": 5},
                 search_params: dict[str, int] = {"checks": 50}):
        """
        Constructor

        Args:
            K (np.ndarray): camera intrinsics matrix
            index_params (dict[str, int], optional): index parameters for FLANN.
                Defaults to {"algorithm": 1, "trees": 5}.
            search_params (dict[str, int], optional): search parameters for FLANN.
                Defaults to {"checks": 50}.
        """
        self.K = K
        self.sift = cv2.SIFT_create()  # pyright: ignore[reportAttributeAccessIssue]
        self.flann = cv2.FlannBasedMatcher(
            indexParams=index_params, searchParams=search_params)  # pyright: ignore[reportArgumentType]

    def extract_keypoints(self, img: cv2.typing.MatLike) -> tuple[list[cv2.KeyPoint], np.ndarray]:
        """
        Detects SIFT keypoints in an image and computes their descriptors

        Args:
            img (cv2.typing.MatLike): input image

        Returns:
            kp (list[cv2.KeyPoint]): list of keypoints
            desc (np.ndarray): descriptors of the keypoints
        """
        return self.sift.detectAndCompute(img, None)

    def match_keypoints(self, desc1: np.ndarray, desc2: np.ndarray,
                        k: int = 2) -> Sequence[Sequence[cv2.DMatch]]:
        """
        Matches keypoint descriptors between two images

        Args:
            desc1 (np.ndarray): image 1 keypoint descriptors
            desc2 (np.ndarray): image 2 keypoint descriptors
            k (int, optional): number of nearest neighbours per query. Defaults to 2.

        Returns:
            Sequence[Sequence[cv2.DMatch]]: sequence of k-nearest-neighbour matches
        """
        return self.flann.knnMatch(desc1, desc2, k=k)

    def filter_matches(self, matches: Sequence[Sequence[cv2.DMatch]],
                       distance_threshold: float = 0.7) -> list[cv2.DMatch]:
        """
        Filters out good keypoint matches using Lowe's ratio test

        Args:
            matches (Sequence[Sequence[cv2.DMatch]]): list of keypoint matches
            distance_threshold (float, optional): ratio threshold between the best
                and second-best match distances. Defaults to 0.7.

        Returns:
            list[cv2.DMatch]: list of good matches
        """
        # Keep a match only if it is significantly closer than its runner-up
        return [m for m, n in matches if m.distance < distance_threshold * n.distance]
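    # The next step recovers relative camera motion from the filtered 2D-2D
    # correspondences: the essential matrix E encodes the epipolar constraint
    # between two calibrated views, is estimated robustly with RANSAC, and is
    # then decomposed (with a cheirality check in cv2.recoverPose) into a
    # rotation R and translation t. Note that the recovered t is a unit
    # vector: a monocular camera cannot observe the metric scale of the motion.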
    def estimate_motion(self, kp1: list[cv2.KeyPoint], kp2: list[cv2.KeyPoint],
                        matches: list[cv2.DMatch]) -> RigidTransform:
        """
        Estimates the relative camera motion between two matched views

        Args:
            kp1 (list[cv2.KeyPoint]): first image keypoints
            kp2 (list[cv2.KeyPoint]): second image keypoints
            matches (list[cv2.DMatch]): list of accepted matches

        Returns:
            RigidTransform: relative transform between the two camera poses
                (translation recovered only up to scale)
        """
        pts1 = np.float32([kp1[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)  # pyright: ignore[reportArgumentType]
        pts2 = np.float32([kp2[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)  # pyright: ignore[reportArgumentType]
        E, _ = cv2.findEssentialMat(
            points1=pts1, points2=pts2, cameraMatrix=self.K,
            method=cv2.RANSAC, prob=0.999, threshold=1.0)
        _, R, t, _ = cv2.recoverPose(E, pts1, pts2, cameraMatrix=self.K)
        return RigidTransform.from_components(
            translation=t.transpose(), rotation=Rotation.from_matrix(R))

    def draw_keypoint_matches(self, img1: cv2.typing.MatLike, kp1: list[cv2.KeyPoint],
                              img2: cv2.typing.MatLike, kp2: list[cv2.KeyPoint],
                              matches: list[cv2.DMatch],
                              output_image: Optional[cv2.typing.MatLike] = None) -> cv2.typing.MatLike:
        """
        Draws the keypoint matches between two images side by side

        Args:
            img1 (cv2.typing.MatLike): first image
            kp1 (list[cv2.KeyPoint]): first image keypoints
            img2 (cv2.typing.MatLike): second image
            kp2 (list[cv2.KeyPoint]): second image keypoints
            matches (list[cv2.DMatch]): list of accepted matches
            output_image (Optional[cv2.typing.MatLike], optional): output image buffer.
                If None or omitted, a new one is created.

        Returns:
            cv2.typing.MatLike: image with the matches drawn
        """
        return cv2.drawMatches(img1, kp1, img2, kp2, matches, output_image,
                               flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS)  # pyright: ignore[reportArgumentType, reportCallIssue]

    @staticmethod
    def show_keypoint_matches(match_image: cv2.typing.MatLike) -> None:
        """
        Shows the match image with matplotlib

        Args:
            match_image (cv2.typing.MatLike): match image
        """
        plt.figure(figsize=(15, 10))
        plt.imshow(match_image, cmap='gray')
        plt.title('Matched Keypoints Between Two Images')
        plt.axis('off')
        plt.show()


def main():
    # Create database for dataset
    # TODO Move this to dataset library
    engine = create_engine('sqlite:///:memory:')
    create_tables(engine)
    session = Session(bind=engine)

    # Import dataset
    data_root = Path('./train1/3d20ae25-5b29-320d-8bae-f03e9dc177b9')
    dataset = Dataset.import_dataset(data_root, session)

    # Get camera
    camera = session.query(Camera).filter_by(name='ring_front_center').first()
    if camera is None:
        raise RuntimeError("Camera not found")

    # Load images
    view1 = camera.camera_views[0]
    view2 = camera.camera_views[1]
    img1_path = view1.get_path()
    img2_path = view2.get_path()
    img1 = cv2.imread(img1_path)
    img2 = cv2.imread(img2_path)
    if img1 is None:
        raise RuntimeError(f"Could not open or find the image {img1_path}")
    if img2 is None:
        raise RuntimeError(f"Could not open or find the image {img2_path}")

    # Create an instance of the VisualOdometry class
    vo = VisualOdometry(K=camera.get_intrinsics())

    # Extract keypoints
    kp1, desc1 = vo.extract_keypoints(img1)
    kp2, desc2 = vo.extract_keypoints(img2)

    # Match keypoints
    matches = vo.match_keypoints(desc1, desc2)

    # Filter keypoints
    good_matches = vo.filter_matches(matches)

    # Draw matches
    img_matches = vo.draw_keypoint_matches(img1, kp1, img2, kp2, good_matches)

    # Show matches
    VisualOdometry.show_keypoint_matches(img_matches)

    # Estimate pose
    T = vo.estimate_motion(kp1, kp2, good_matches)
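    # Note: T is the camera-frame motion with a unit-length translation, while
    # the recorded poses below are metric vehicle poses, so the translations
    # can only agree up to scale (and, in general, the camera-to-vehicle
    # extrinsics also affect the comparison).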
    # Get recorded poses
    T_wa = view1.timestamp.get_vehicle_pose()
    T_wb = view2.timestamp.get_vehicle_pose()
    T_ba = relative_transform(T_wa, T_wb)

    # Print results
    t, R = T.as_components()
    t_ba, R_ba = T_ba.as_components()
    print(f"Calculated Rotation matrix: \n{R.as_matrix()}")
    print(f"Recorded Rotation matrix: \n{R_ba.as_matrix()}")
    print(f"Difference Rotation matrix: \n{R_ba.as_matrix() - R.as_matrix()}")
    print()
    print(f"Calculated Euler Angles: {R.as_euler('xyz')}")
    print(f"Recorded Euler Angles: {R_ba.as_euler('xyz')}")
    print(f"Difference Euler Angles: {R_ba.as_euler('xyz') - R.as_euler('xyz')}")
    print()
    print(f"Calculated Quaternion: {R.as_quat()}")
    print(f"Recorded Quaternion: {R_ba.as_quat()}")
    print(f"Difference Quaternion: {R_ba.as_quat() - R.as_quat()}")
    print()
    print(f"Calculated Translation: {t}")
    print(f"Recorded Translation: {t_ba}")
    print(f"Difference Translation: {t_ba - t}")


if __name__ == '__main__':
    main()
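# A possible refinement (a sketch, not part of the pipeline above): element-wise
# differences between rotation matrices or quaternions are hard to interpret,
# and quaternions additionally carry a q / -q sign ambiguity. A common
# alternative is the geodesic rotation error plus the angle between the unit
# translation directions, e.g.:
#
#     rot_err_deg = np.degrees((R_ba * R.inv()).magnitude())
#     t_dir = t.ravel() / np.linalg.norm(t)
#     t_ba_dir = t_ba.ravel() / np.linalg.norm(t_ba)
#     trans_err_deg = np.degrees(np.arccos(np.clip(np.dot(t_dir, t_ba_dir), -1.0, 1.0)))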