diff --git a/navsim/planning/simulation/planner/pdm_planner/simulation/batch_lqr_utils.py b/navsim/planning/simulation/planner/pdm_planner/simulation/batch_lqr_utils.py index 408ceef..adf6cd4 100644 --- a/navsim/planning/simulation/planner/pdm_planner/simulation/batch_lqr_utils.py +++ b/navsim/planning/simulation/planner/pdm_planner/simulation/batch_lqr_utils.py @@ -1,3 +1,4 @@ +import torch from typing import Tuple import numpy as np @@ -119,9 +120,13 @@ def _fit_initial_velocity_and_acceleration_profile( A_T, R_T = np.transpose(A, (0, 2, 1)), np.transpose(R, (0, 2, 1)) + # Convert A and R to PyTorch tensors + A_tensor = torch.tensor(batch_matmul(A_T, A), dtype=torch.float32) + R_tensor = torch.tensor(batch_matmul(R_T, R), dtype=torch.float32) + # Compute regularized least squares solution. intermediate_solution = batch_matmul( - np.linalg.pinv(batch_matmul(A_T, A) + jerk_penalty * batch_matmul(R_T, R)), A_T + torch.linalg.inv(A_tensor + jerk_penalty * R_tensor), A_T ) x = np.einsum("bij, bj -> bi", intermediate_solution, y) @@ -174,9 +179,14 @@ def _fit_initial_curvature_and_curvature_rate_profile( Q[0, 0] = initial_curvature_penalty # Compute regularized least squares solution. - A_T = A.transpose(0, 2, 1) + A_T = A.transpose(0, 2, 1) # 确保 A_T 在这里定义 + + # Convert A and Q to PyTorch tensors + A_tensor = torch.tensor(batch_matmul(A_T, A), dtype=torch.float32) + Q_tensor = torch.tensor(Q, dtype=torch.float32) # Convert Q to tensor - intermediate = batch_matmul(np.linalg.pinv(batch_matmul(A_T, A) + Q), A_T) + # Compute regularized least squares solution. + intermediate = batch_matmul(torch.linalg.inv(A_tensor + Q_tensor), A_T) # Use Q_tensor x = np.einsum("bij,bj->bi", intermediate, y) # Extract profile from solution.