import numpy as np
import matplotlib.pyplot as plt
def _fit_line(xs, ys):
    """Return (slope, intercept) of the ordinary least-squares line through (xs, ys).

    When xs has zero variance (e.g. a single point), the slope falls back to 0,
    giving a horizontal line through the mean of ys.
    """
    x_mean = np.mean(xs)
    y_mean = np.mean(ys)
    denom = np.sum((xs - x_mean) ** 2)
    slope = 0.0 if denom == 0 else np.sum((xs - x_mean) * (ys - y_mean)) / denom
    intercept = y_mean - slope * x_mean
    return slope, intercept


def compute_error_matrix(x, y):
    """
    Precompute the squared error of the best-fit line for every interval [i, j].

    Args:
        x, y: array-likes of equal length (lists are accepted; both are
            converted to float NumPy arrays).

    Returns:
        An (n, n) float array where err[i][j] (for i <= j) is the sum of
        squared residuals of the least-squares line through points i..j
        inclusive. Entries with i > j are never written and stay 0.

    Raises:
        ValueError: if x and y have different lengths.

    Note: every one of the O(n^2) intervals is refit from scratch, so this
    runs in O(n^3) time.
    """
    # Convert up front so slicing yields arrays and `xs - mean` works for list input.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if len(x) != len(y):
        raise ValueError(
            f"x and y must have the same length (got {len(x)} and {len(y)})"
        )
    n = len(x)
    err = np.zeros((n, n))
    for i in range(n):
        for j in range(i, n):
            xs = x[i:j + 1]
            ys = y[i:j + 1]
            slope, intercept = _fit_line(xs, ys)
            residuals = ys - (slope * xs + intercept)
            err[i, j] = np.sum(residuals ** 2)
    return err
def segmented_least_squares(x, y, C):
    """
    Dynamic Programming solution to the Segmented Least Squares problem.

    Partitions the points into contiguous runs of indices ("segments"), each
    fitted by its own least-squares line, minimising

        (sum of squared errors over all segments) + C * (number of segments).

    Segments are contiguous in index order, so the points should normally be
    sorted by x (this is not checked here).

    Args:
        x, y: sequences of equal length giving the data points.
        C: penalty charged once per segment.

    Returns:
        (OPT, segments) where OPT[j] is the minimum cost of covering points
        0..j (so OPT[-1] is the optimum for all the data), and segments is a
        list of inclusive (start_index, end_index) pairs of built-in ints in
        increasing order.
    """
    n = len(x)
    err = compute_error_matrix(x, y)
    OPT = np.zeros(n)
    # segment_index[j]: start of the last segment in an optimal solution for points 0..j.
    segment_index = np.zeros(n, dtype=int)
    for j in range(n):
        min_cost = float('inf')
        best_i = 0
        for i in range(j + 1):
            # Last segment is i..j; the points before i are covered optimally (OPT[i - 1]).
            cost = err[i][j] + C + (OPT[i - 1] if i > 0 else 0)
            # Strict '<' keeps the smallest i on ties, i.e. the longest last segment.
            if cost < min_cost:
                min_cost = cost
                best_i = i
        OPT[j] = min_cost
        segment_index[j] = best_i

    # Walk back from the last point, peeling off one optimal last segment at a time.
    # int() turns NumPy scalars into plain ints, so the list prints as (0, 19)
    # rather than (np.int64(0), np.int64(19)) under NumPy >= 2.
    segments = []
    j = n - 1
    while j >= 0:
        i = int(segment_index[j])
        segments.append((i, j))
        j = i - 1
    segments.reverse()
    return OPT, segments
def main():
    """Demo: segment a noisy three-piece linear signal and plot the fitted lines."""
    np.random.seed(42)
    x = np.linspace(0, 30, 60)
    # Three linear regimes plus Gaussian noise (std 2). np.piecewise passes each
    # function only the x values where its condition holds, so len(xs) is that
    # region's size and the noise vector matches it.
    y = np.piecewise(
        x,
        [x < 10, (x >= 10) & (x < 20), x >= 20],
        [lambda xs: 2 * xs + np.random.normal(0, 2, len(xs)),
         lambda xs: -xs + 30 + np.random.normal(0, 2, len(xs)),
         lambda xs: 0.5 * xs + 5 + np.random.normal(0, 2, len(xs))],
    )
    C = 40  # cost charged per segment; larger values favour fewer, longer segments

    OPT, segments = segmented_least_squares(x, y, C)
    print("Minimum total cost:", round(OPT[-1], 3))
    # Cast to built-in ints: under NumPy >= 2 integer scalars print as np.int64(...).
    print("Optimal segments (start_index, end_index):",
          [(int(i), int(j)) for i, j in segments])

    plt.scatter(x, y, color='blue', label='Data points', s=25)
    for k, (i, j) in enumerate(segments):
        xi = x[i:j + 1]
        yi = y[i:j + 1]
        x_mean = np.mean(xi)
        y_mean = np.mean(yi)
        denom = np.sum((xi - x_mean) ** 2)
        # Zero x-variance (e.g. a single-point segment): fall back to slope 0.
        slope = np.sum((xi - x_mean) * (yi - y_mean)) / denom if denom != 0 else 0
        intercept = y_mean - slope * x_mean
        # Label only the first line so the legend gets a single "Fitted segments" entry.
        plt.plot(xi, slope * xi + intercept, color='red', linewidth=2.5,
                 label='Fitted segments' if k == 0 else '_nolegend_')
    plt.title("Segmented Least Squares with Many Points")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.legend()
    plt.grid(True)
    plt.show()


if __name__ == "__main__":
    main()