diff --git a/.gitignore b/.gitignore index 37cbe99..f998999 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ .venv .vscode train1 -claude.md \ No newline at end of file +claude.md +__pycache__ \ No newline at end of file diff --git a/data_parser/__init__.py b/data_parser/__init__.py new file mode 100644 index 0000000..ee177f7 --- /dev/null +++ b/data_parser/__init__.py @@ -0,0 +1 @@ +from data_parser.data_parser import create_tables, Dataset, Timestamp, Camera, CameraView \ No newline at end of file diff --git a/data_parser/data_parser.py b/data_parser/data_parser.py new file mode 100644 index 0000000..9242b93 --- /dev/null +++ b/data_parser/data_parser.py @@ -0,0 +1,409 @@ +import json +from os import PathLike +from pathlib import Path +from typing import Optional + +import numpy as np +from scipy.spatial.transform import RigidTransform, Rotation + +from sqlalchemy import Column, Engine, Float, Integer, Enum, ForeignKey, String +from sqlalchemy.orm import DeclarativeBase, Session, relationship + +import re + +class ModelBase(DeclarativeBase): + pass + +metadata = ModelBase.metadata + +def create_tables(engine:Engine): + metadata.create_all(engine) + +ts_pose_re = re.compile(r"city_SE3_egovehicle_(?P<timestamp>\d*)") +cam_name_re = re.compile(r"image_raw_(?P<name>.*)") +camview_fn_re = re.compile(r"(?P<cam_name>[^\d]*)_(?P<timestamp>\d*)") + + +def json_to_pose(data: dict) -> RigidTransform: + """ Creates a pose from an Argoverse pose json + + Args: + data (dict): json contents + + Returns: + RigidTransform: pose in the json data + """ + # TODO improve json validation + + if 'rotation' in data: + if 'coefficients' in data['rotation']: + rot_data = data['rotation']['coefficients'] + else: + rot_data = data['rotation'] + else: + raise RuntimeError('No rotation present in pose data') + + if 'translation' in data: + trans_data = data['translation'] + else: + raise RuntimeError('No translation present in pose data') + + + return RigidTransform.from_components( + np.array(trans_data), + 
Rotation.from_quat(rot_data)) + + +class Dataset(ModelBase): + """ Represents a dataset """ + __tablename__ = "datasets" + id = Column(Integer, primary_key=True, autoincrement=True) + root_path = Column(String(256)) + key = Column(String(36)) + city = Column(String(3), default="") + + timestamps = relationship("Timestamp", back_populates='dataset') + cameras = relationship("Camera", back_populates='dataset') + camera_views = relationship("CameraView", back_populates='dataset') + + def __init__(self, root_path: PathLike, key: str, city: str = ''): + """ Constructor + + Args: + key (str): key of the dataset + root_path (PathLike): Root of the data directory + """ + self.key = key + self.city = city if len(city) <= 3 else city[0:3] + self.root_path = str(root_path) + + @staticmethod + def import_dataset(root_path: PathLike, session:Session) -> 'Dataset': + """ Imports a dataset + + Args: + root_path (PathLike): root path to the dataset + session (Session): SQLAlchemy session to use + + Returns: + Dataset: Reference to the dataset + """ + + root_path = Path(root_path) + key = root_path.stem + city = '' + + city_info_path = root_path / "city_info.json" + calib_path = root_path / "vehicle_calibration_info.json" + pose_root = root_path / "poses" + + # Get City Information + if city_info_path.exists(): + with open(city_info_path, 'r') as f: + data = json.load(f) + city = data['city'] if 'city' in data else '' + + # Create the new dataset + dataset = Dataset(root_path, key, city) + + # Create timestamps from poses + for path in pose_root.glob("*.json"): + # TODO Validate json + timestamp = Timestamp.from_json(path, dataset, session) + + # Create Cameras + with open(calib_path, 'r') as f: + data = json.load(f) + + # TODO Validate file + cameras = [] + for camera_data in data['camera_data_']: + camera = Camera.from_json(camera_data, dataset, session) + + if camera is not None: + cameras.append(camera) + + # Create Camera Views + for camera in cameras: + camera_dir = root_path 
/ camera.name + + for path in camera_dir.glob("*.*"): + CameraView.from_path(path, dataset, session) + + return dataset + + +class Timestamp(ModelBase): + """ Correlates data at a given timestamp """ + __tablename__ = "timestamps" + id = Column(Integer, primary_key=True, autoincrement=True) + time = Column(Integer) + + vehicle_pose_x = Column(Float) # x component + vehicle_pose_y = Column(Float) # y component + vehicle_pose_z = Column(Float) # z component + vehicle_pose_q0 = Column(Float) # Quaternion component 0 + vehicle_pose_q1 = Column(Float) # Quaternion component 1 + vehicle_pose_q2 = Column(Float) # Quaternion component 2 + vehicle_pose_q3 = Column(Float) # Quaternion component 3 + + dataset_id = Column(Integer, ForeignKey('datasets.id'), nullable=False) + dataset = relationship('Dataset', back_populates='timestamps') + + camera_views = relationship("CameraView", back_populates="timestamp") + + def __init__(self, time: int, vehicle_pose: RigidTransform, dataset: Dataset): + """ Constructor + + Args: + time (int): nanosecond time for the timestamp + vehicle_pose (RigidTransform): vehicle pose in world coordinates + dataset (Dataset): dataset this timestamp belongs to + + """ + + self.time = time + + self.set_vehicle_pose(vehicle_pose) + + self.dataset = dataset + + def set_vehicle_pose(self, vehicle_pose: RigidTransform): + """ Sets the vehicle's pose + + Args: + vehicle_pose (RigidTransform): vehicle pose in world coordinates + """ + t, R = vehicle_pose.as_components() + quat = R.as_quat() + + self.vehicle_pose_x = t[0] + self.vehicle_pose_y = t[1] + self.vehicle_pose_z = t[2] + self.vehicle_pose_q0 = quat[0] + self.vehicle_pose_q1 = quat[1] + self.vehicle_pose_q2 = quat[2] + self.vehicle_pose_q3 = quat[3] + + def get_vehicle_pose(self) -> RigidTransform: + """ Returns the vehicle pose as a RigidTransform object """ + return RigidTransform.from_components( + translation=np.array( + [self.vehicle_pose_x, self.vehicle_pose_y, self.vehicle_pose_z]), + 
rotation=Rotation.from_quat([self.vehicle_pose_q0, self.vehicle_pose_q1, self.vehicle_pose_q2, self.vehicle_pose_q3])) + + @staticmethod + def from_json(path: PathLike, dataset: Dataset, session:Session) -> Optional["Timestamp"]: + """ Creates a timestamp object from a pose file path and dataset object + + Args: + path (PathLike): Path to pose file + dataset (Dataset): Dataset object to which this timestamp + session (Session): SQLAlchemy session to use + + Returns: + Optional[Timestamp]: new timestamp object if it is successfully created, None otherwise + """ + result = None + path = Path(path) + ts_pose_match = ts_pose_re.match(str(path.stem)) + if ts_pose_match is not None: + time = int(ts_pose_match.group("timestamp")) + + with open(path, 'r') as f: + # TODO Add exception handling + result = Timestamp(time, json_to_pose(json.load(f)), dataset) + session.add(result) + session.commit() + + return result + + +class Camera(ModelBase): + """ Represents a camera in the dataset """ + __tablename__ = "cameras" + + id = Column(Integer, primary_key=True) + name = Column(String(20), nullable=False) + + fx = Column(Float, nullable=False) + fy = Column(Float, nullable=False) + cx = Column(Float, nullable=False) + cy = Column(Float, nullable=False) + + extrinsics_x = Column(Float) # x component + extrinsics_y = Column(Float) # y component + extrinsics_z = Column(Float) # z component + extrinsics_q0 = Column(Float) # Quaternion component 0 + extrinsics_q1 = Column(Float) # Quaternion component 1 + extrinsics_q2 = Column(Float) # Quaternion component 2 + extrinsics_q3 = Column(Float) # Quaternion component 3 + + dataset_id = Column(Integer, ForeignKey('datasets.id'), nullable=False) + dataset = relationship('Dataset', back_populates='cameras') + + camera_views = relationship("CameraView", back_populates="camera") + + def __init__(self, name: str, intrinsics: np.ndarray, extrinsics: RigidTransform, dataset: Dataset): + """ Constructor + + Args: + name (str): Name of the camera 
+ intrinsics (np.ndarray): Camera intrinsics matrix + extrinsics (RigidTransform): Camera extrinsics transform + dataset (Dataset): dataset this camera belongs to + """ + self.name = name + self.set_intrinsics(intrinsics) + self.set_extrinsics(extrinsics) + self.dataset = dataset + + def set_intrinsics(self, intrinsics: np.ndarray): + """ Sets the camera intrinsics matrix + + Args: + intrinsics (np.ndarray): Camera intrinsics matrix + """ + self.fx = intrinsics[0, 0] + self.fy = intrinsics[1, 1] + self.cx = intrinsics[0, 2] + self.cy = intrinsics[1, 2] + + def set_extrinsics(self, extrinsics: RigidTransform): + """ Sets the camera extrinsics + + Args: + extrinsics (RigidTransform): camera extrinsics in vehicle coordinates + """ + t, R = extrinsics.as_components() + quat = R.as_quat() + + self.extrinsics_x = t[0] + self.extrinsics_y = t[1] + self.extrinsics_z = t[2] + self.extrinsics_q0 = quat[0] + self.extrinsics_q1 = quat[1] + self.extrinsics_q2 = quat[2] + self.extrinsics_q3 = quat[3] + + def get_intrinsics(self) -> np.ndarray: + """ Gets the camera intrinsics matrix + + Returns: + np.ndarray: 3 x 3 camera intrinsics matrix + + """ + + return np.array([[self.fx, 0, self.cx], + [0, self.fy, self.cy], + [0, 0, 1]]) + + def get_extrinsics(self) -> RigidTransform: + """ Returns the extrinsics as a RigidTransform object """ + return RigidTransform.from_components( + translation=np.array( + [self.extrinsics_x, self.extrinsics_y, self.extrinsics_z]), + rotation=Rotation.from_quat([self.extrinsics_q0, self.extrinsics_q1, self.extrinsics_q2, self.extrinsics_q3])) + + @staticmethod + def from_json(data: dict, dataset: Dataset, session:Session) -> Optional["Camera"]: + """ Creates a camera from json data + + Args: + data (dict): camera data dict + dataset (Dataset): dataset object to which the camera belongs + session (Session): SQLAlchemy session to use + + Returns: + Camera: Camera object if the data is valid, None otherwise + + """ + result = None + # TODO Validate data 
+ + name_match = cam_name_re.match(data['key']) + + if name_match is not None: + name = name_match.group('name') + + values = data['value'] + + intrinsics = np.array([[values['focal_length_x_px_'], 0, values['focal_center_x_px_']], + [0, values['focal_length_y_px_'], values['focal_center_x_px_']], + [0, 0, 1]]) + extrinsics = json_to_pose(values['vehicle_SE3_camera_']) + + result = Camera(name, intrinsics, extrinsics, dataset) + session.add(result) + session.commit() + + return result + +class CameraView(ModelBase): + """ Represents a camera view at a given timestamp """ + __tablename__ = "camera_views" + + id = Column(Integer, primary_key=True) + + ext = Column(String(3), nullable=False) + + dataset_id = Column(Integer, ForeignKey('datasets.id'), nullable=False) + dataset = relationship('Dataset', back_populates='camera_views') + + timestamp_id = Column(Integer, ForeignKey('timestamps.id'), nullable=False) + timestamp = relationship('Timestamp', back_populates='camera_views') + + camera_id = Column(Integer, ForeignKey('cameras.id'), nullable=False) + camera = relationship('Camera', back_populates='camera_views', order_by=Timestamp.time) + + def __init__(self, ext: str, dataset: Dataset, timestamp: Timestamp, camera: Camera): + """ Constructor + + Args: + dataset (Dataset): dataset this camera view belongs to + timestamp (Timestamp): timestamp this camera view was taken + camera (Camera): camera that took this camera view + """ + self.ext = ext + self.dataset = dataset + self.timestamp = timestamp + self.camera = camera + + def get_path(self) -> Path: + """ Get the path to the camera view image file + + Returns: + Path: Path to the camera view image file + """ + return Path(str(self.dataset.root_path)) / str(self.camera.name) / f"{self.camera.name}_{self.timestamp.time}.{self.ext}" + + @staticmethod + def from_path(path:PathLike, dataset:Dataset, session:Session) -> Optional['CameraView']: + """ Creates a camera view from a file path + + Args: + path (PathLike): 
Path to the camera view image file + dataset (Dataset): dataset this camera view belongs to + session (Session): SQLAlchemy session to use + + Returns: + Optional[CameraView]: CameraView object if the path is valid, None otherwise + """ + result = None + + path = Path(path) + + camview_fn_match = camview_fn_re.match(path.stem) + if camview_fn_match is not None: + time = int(camview_fn_match.group('timestamp')) + cam_name = camview_fn_match.group('cam_name') + + # Scope lookups to this dataset so a matching time/name in another dataset is not picked up + timestamp = session.query(Timestamp).filter_by(time=time, dataset=dataset).first() + camera = session.query(Camera).filter_by(name=cam_name, dataset=dataset).first() + + if camera is not None and timestamp is not None: + result = CameraView(path.suffix[1:], dataset, timestamp, camera) + session.add(result) + session.commit() + + return result + diff --git a/util.py b/util.py new file mode 100644 index 0000000..39e2ec9 --- /dev/null +++ b/util.py @@ -0,0 +1,25 @@ +import json +from os import PathLike +from pathlib import Path + +from scipy.spatial.transform import RigidTransform, Rotation +import numpy as np + + +def relative_transform(world_to_a: RigidTransform, world_to_b: RigidTransform) -> RigidTransform: + """ Computes the relative transform between two poses represented by a transform from the same + world coordinates + + Args: + world_to_a (RigidTransform): transform from world origin to the first pose + world_to_b (RigidTransform): transform from world origin to the second pose + + Returns: + RigidTransform: transform to second pose from first pose b_to_a + """ + + Twa = world_to_a.as_matrix() + Tbw = np.linalg.inv(world_to_b.as_matrix()) + Tab = Tbw @ Twa + + return RigidTransform.from_matrix(Tab) \ No newline at end of file diff --git a/visualOdometry.py b/visualOdometry.py index 4d7bd90..6de9083 100644 --- a/visualOdometry.py +++ b/visualOdometry.py @@ -1,8 +1,16 @@ +from pathlib import Path from typing import Optional, Sequence import cv2 + import numpy as np from matplotlib import pyplot as plt +from scipy.spatial.transform import Rotation, RigidTransform 
+from sqlalchemy import create_engine, desc +from sqlalchemy.orm import Session + +from util import relative_transform +from data_parser import * class VisualOdometry: @@ -19,8 +27,8 @@ class VisualOdometry: search_params (dict[str, int], optional): Search parameters for FLANN. Defaults to {"checks": 50}. """ self.K = K - # pyright: ignore[reportAttributeAccessIssue] - self.sift = cv2.SIFT_create() + + self.sift = cv2.SIFT_create() # pyright: ignore[reportAttributeAccessIssue] self.flann = cv2.FlannBasedMatcher( indexParams=index_params, searchParams=search_params) # pyright: ignore[reportArgumentType] @@ -65,17 +73,34 @@ class VisualOdometry: """ return [m for m, n in matches if m.distance < distance_threshold * n.distance] - def estimate_motion(self, kp1: list[cv2.KeyPoint], kp2: list[cv2.KeyPoint], matches: list[cv2.DMatch]): - """ Estimates the motion between two images + def estimate_motion(self, kp1: list[cv2.KeyPoint], kp2: list[cv2.KeyPoint], matches: list[cv2.DMatch]) -> RigidTransform: + """ Estimates the relative camera motion between two images Args: - kp1 (list[cv2.KeyPoint]): first image keypoints - kp2 (list[cv2.KeyPoint]): second image keypoints - matches (list[cv2.DMatch]): list of keypoint matches + kp1 (list[cv2.KeyPoint]): first image keypoints + kp2 (list[cv2.KeyPoint]): second image keypoints + matches (list[cv2.DMatch]): list of keypoint matches + Returns: - TODO: Add returns + RigidTransform: estimated transform from the first to the second camera pose """ + pts1 = np.float32([kp1[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2) # pyright: ignore[reportArgumentType] + pts2 = np.float32([kp2[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2) # pyright: ignore[reportArgumentType] + + E, _ = cv2.findEssentialMat( + points1=pts1, + points2=pts2, + cameraMatrix=self.K, + method=cv2.RANSAC, + prob=.999, + threshold=1.0) + + _, R, t, _ = cv2.recoverPose(E, pts1, pts2, cameraMatrix=self.K) + + return RigidTransform.from_components(translation=t.transpose(), rotation=Rotation.from_matrix(R)) # NOTE(review): recoverPose translation is unit-norm (monocular scale ambiguity) + + def draw_keypoint_matches(self, img1: 
cv2.typing.MatLike, kp1: list[cv2.KeyPoint], @@ -98,9 +123,9 @@ class VisualOdometry: """ # Draw matches - # pyright: ignore[reportArgumentType, reportCallIssue] - return cv2.drawMatches(img1, kp1, img2, kp2, matches, output_image, flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS) - + + return cv2.drawMatches(img1, kp1, img2, kp2, matches, output_image, flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS) # pyright: ignore[reportArgumentType, reportCallIssue] + @staticmethod def show_keypoint_matches(match_image: cv2.typing.MatLike) -> None: """ Show image matches @@ -115,19 +140,29 @@ plt.show() -def main(): - # Set Camera Intrinsics - K = np.array( - [[1389.2414846481593, 0, 962.3421649150145], - [0, 1389.2414846481593, 605.814069325842], - [0, 0, 1]], - dtype=np.float64) - # Set Image Paths - img1_path = ".\\train1\\3d20ae25-5b29-320d-8bae-f03e9dc177b9\\ring_front_center\\ring_front_center_315975023006264672.jpg" - img2_path = ".\\train1\\3d20ae25-5b29-320d-8bae-f03e9dc177b9\\ring_front_center\\ring_front_center_315975023039564872.jpg" +def main(): + # Create database for dataset + # TODO Move this to dataset library + engine = create_engine('sqlite:///:memory:') + create_tables(engine) + session = Session(bind=engine) + + # Import dataset + data_root = Path('./train1/3d20ae25-5b29-320d-8bae-f03e9dc177b9') + dataset = Dataset.import_dataset(data_root, session) + + # Get Camera + camera = session.query(Camera).filter_by(name='ring_front_center').first() + + if camera is None: + raise RuntimeError("Camera not found") # Load images + view1 = camera.camera_views[0] + view2 = camera.camera_views[1] + img1_path = view1.get_path() + img2_path = view2.get_path() -img1 = cv2.imread(img1_path) -img2 = cv2.imread(img2_path) +img1 = cv2.imread(str(img1_path)) +img2 = cv2.imread(str(img2_path)) @@ -138,7 +173,7 @@ def main(): raise RuntimeError(f"Could not open or find the image {img2_path}") # Create an instance of the VisualOdometry class - vo = VisualOdometry(K=K) + vo = VisualOdometry(K=camera.get_intrinsics()) # Extract 
Keypoints kp1, desc1 = vo.extract_keypoints(img1) @@ -156,6 +191,37 @@ # Show Matches VisualOdometry.show_keypoint_matches(img_matches) + # Estimate pose + T = vo.estimate_motion(kp1, kp2, good_matches) + + # Get Recorded Poses + T_wa = view1.timestamp.get_vehicle_pose() + T_wb = view2.timestamp.get_vehicle_pose() + T_ba = relative_transform(T_wa, T_wb) + + # Print results + t, R = T.as_components() + t_ba, R_ba = T_ba.as_components() + + print(f"Calculated Rotation matrix: \n{R.as_matrix()}") + print(f"Recorded Rotation matrix: \n{R_ba.as_matrix()}") + print(f"Difference Rotation matrix: \n{R_ba.as_matrix() - R.as_matrix()}") + print() + + print(f"Calculated Euler Angles: {R.as_euler('xyz')}") + print(f"Recorded Euler Angles: {R_ba.as_euler('xyz')}") + print(f"Difference Euler Angles: {R_ba.as_euler('xyz') - R.as_euler('xyz')}") + print() + + print(f"Calculated Quaternions: {R.as_quat()}") + print(f"Recorded Quaternions: {R_ba.as_quat()}") + print(f"Difference Quaternions: {R_ba.as_quat() - R.as_quat()}") + print() + + print(f"Calculated Translation: {t}") + print(f"Recorded Translation: {t_ba}") + print(f"Difference Translation: {t_ba - t}") + if __name__ == '__main__': main()