Skip to content
Snippets Groups Projects
Commit ec91be2d authored by Amos's avatar Amos
Browse files

optical modified

parent a3e7c06d
No related branches found
No related tags found
No related merge requests found
......@@ -182,43 +182,25 @@ def draw_function(
import numpy as np
from scipy.signal import medfilt
def local_shift_computation(curve1, curve2, max_shift=50, window_radius=5):
def local_shift_computation(curve1, curve2, max_shift=20, window_radius=5):
"""
Computes the local shift between two curves using numpy operations.
curve1: np.array, the first curve (reference)
curve2: np.array, the second curve (to be shifted)
max_shift: int, the maximum allowed shift to consider
window_radius: int, the radius of the local window for similarity computation
Returns:
shifts: np.array, the shift values for each point in the curve
优化局部位移计算算法
"""
length = min(len(curve1), len(curve2))
curve1 = curve1[:length]
curve2 = curve2[:length]
# Initialize shifts
shifts = np.zeros(length)
# Iterate over each pixel position to compute the local shift
for x in range(window_radius, length - window_radius):
min_error = float('inf')
best_shift = 0
# Calculate the error for each possible shift
for s in range(0, max_shift + 1):
error = np.sum((curve1[x - window_radius: x + window_radius + 1] -
np.roll(curve2, s)[x - window_radius: x + window_radius + 1]) ** 2)
if error < min_error:
min_error = error
best_shift = s
shifts[x] = best_shift
# Apply median filter to smooth the shifts
smoothed_shifts = medfilt(shifts, kernel_size=21)
# 初始化误差矩阵
error_matrix = np.full((length, max_shift + 1), float('inf'))
for s in range(0, max_shift + 1):
shifted_curve2 = np.roll(curve2, s)
local_diff = (curve1 - shifted_curve2) ** 2
error_matrix[:, s] = np.convolve(local_diff, np.ones((2 * window_radius + 1,)), mode='same')
best_shifts = np.argmin(error_matrix, axis=1)
smoothed_shifts = medfilt(best_shifts, kernel_size=21)
return smoothed_shifts
......
......@@ -6,7 +6,7 @@ from sensor_msgs.msg import CompressedImage
import cv2
import numpy as np
from cv_bridge import CvBridge
from vanishing_point.my_line_library import local_shift_computation, plot_shift_on_image
from vanishing_point.my_line_library import local_shift_computation, draw_function
class OpticalFlowNode(Node):
def __init__(self):
......@@ -33,7 +33,7 @@ class OpticalFlowNode(Node):
# 计算局部位移
shifts = local_shift_computation(curve1, curve2)
plot_shift_on_image(image, shifts)
draw_function(image, shifts, hmin=0, hmax=image.shape[0], ymin=-50, ymax=50, color=(255, 0, 0), thickness=1)
# 显示图像
cv2.imshow("Optical Flow Image", image)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment