# Author: ChatGPT 5.0

import numpy as np
import matplotlib.pyplot as plt

def compute_error_matrix(x, y):
    """
    Precompute the squared error of the best-fit line for every interval [i, j].

    Parameters
    ----------
    x, y : array-like
        Coordinates of the points, in the order they should be segmented.
        Plain lists/tuples are accepted; they are converted to NumPy arrays.
        Both must have the same length.

    Returns
    -------
    err : np.ndarray of shape (n, n)
        For i <= j, err[i][j] is the sum of squared residuals of the
        ordinary least-squares line fitted to points i..j (inclusive).
        Entries with i > j are never written and stay 0.

    Raises
    ------
    ValueError
        If x and y do not have the same length.
    """
    # Convert once up front: slicing and arithmetic below rely on ndarray
    # semantics (e.g. `0 * list` would yield an empty list, not zeros).
    x = np.asarray(x)
    y = np.asarray(y)

    n = len(x)
    if len(y) != n:
        raise ValueError(
            f"x and y must have the same length (got {n} and {len(y)})"
        )

    err = np.zeros((n, n))

    for i in range(n):
        for j in range(i, n):
            xi = x[i:j+1]
            yi = y[i:j+1]

            x_mean = np.mean(xi)
            y_mean = np.mean(yi)
            x_centered = xi - x_mean
            denom = np.sum(x_centered ** 2)

            # When every x in the segment is identical (this includes the
            # single-point case) the slope is undefined; fall back to the
            # horizontal line through the mean of y.
            if denom == 0:
                slope = 0.0
            else:
                slope = np.sum(x_centered * (yi - y_mean)) / denom
            intercept = y_mean - slope * x_mean

            # Squared error for this segment
            y_pred = slope * xi + intercept
            err[i][j] = np.sum((yi - y_pred) ** 2)

    return err


def segmented_least_squares(x, y, C):
    """
    Dynamic Programming solution to the Segmented Least Squares problem.

    Points are partitioned, in the order given (no sorting is performed
    here), into contiguous segments. Each segment is charged the squared
    error of its best-fit line plus a fixed penalty C.

    Parameters
    ----------
    x, y : sequences of equal length
        Point coordinates; passed straight to compute_error_matrix.
    C : number
        Cost added for every segment used.

    Returns
    -------
    OPT : np.ndarray of shape (n,)
        OPT[j] is the minimum total cost of covering points 0..j.
    segments : list of (int, int)
        Inclusive (start_index, end_index) pairs of the optimal
        segmentation, ordered from left to right. The indices are built-in
        Python ints. Empty when there are no points.
    """
    n = len(x)
    err = compute_error_matrix(x, y)
    OPT = np.zeros(n)
    segment_index = np.zeros(n, dtype=int)

    # Recurrence: OPT[j] = min over i <= j of err[i][j] + C + OPT[i-1],
    # with OPT[-1] taken as 0 (nothing before the first point).
    for j in range(n):
        min_cost = float('inf')
        best_i = 0
        for i in range(j + 1):
            cost = err[i][j] + C + (OPT[i - 1] if i > 0 else 0)
            # Strict '<' keeps the smallest i on ties.
            if cost < min_cost:
                min_cost = cost
                best_i = i
        OPT[j] = min_cost
        segment_index[j] = best_i

    # Walk back from the last point, jumping to the point just before the
    # start of each chosen segment.
    segments = []
    j = n - 1
    while j >= 0:
        # Cast to built-in int so callers get plain (int, int) tuples rather
        # than NumPy scalars (cleaner repr, JSON-serialisable).
        i = int(segment_index[j])
        segments.append((i, j))
        j = i - 1

    segments.reverse()
    return OPT, segments


# --- Example with Many Points ---
if __name__ == "__main__":
    # Fixed seed so the noisy sample, and therefore the output, is repeatable
    np.random.seed(42)

    # 60 evenly spaced x values on [0, 30]; y follows three linear pieces,
    # each perturbed by Gaussian noise (mean 0, standard deviation 2)
    x = np.linspace(0, 30, 60)
    regions = [x < 10, (x >= 10) & (x < 20), x >= 20]
    pieces = [
        lambda t: 2*t + np.random.normal(0, 2, len(t)),
        lambda t: -t + 30 + np.random.normal(0, 2, len(t)),
        lambda t: 0.5*t + 5 + np.random.normal(0, 2, len(t)),
    ]
    y = np.piecewise(x, regions, pieces)

    # Penalty for introducing a new line segment
    C = 40

    OPT, segments = segmented_least_squares(x, y, C)

    print("Minimum total cost:", round(OPT[-1], 3))
    print("Optimal segments (start_index, end_index):", segments)

    # Raw data as a scatter plot
    plt.scatter(x, y, color='blue', label='Data points', s=25)

    # Refit a least-squares line on each optimal segment and draw it
    for start, end in segments:
        seg_x = x[start:end + 1]
        seg_y = y[start:end + 1]

        mean_x = np.mean(seg_x)
        mean_y = np.mean(seg_y)
        spread = np.sum((seg_x - mean_x) ** 2)

        # A segment with no spread in x gets a flat line through mean_y
        if spread != 0:
            slope = np.sum((seg_x - mean_x) * (seg_y - mean_y)) / spread
        else:
            slope = 0
        intercept = mean_y - slope * mean_x

        fitted = slope * seg_x + intercept
        plt.plot(seg_x, fitted, color='red', linewidth=2.5)

    plt.title("Segmented Least Squares with Many Points")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.legend()
    plt.grid(True)
    plt.show()