Python: scipy.optimize.curve_fit returns a straight-line fit

Question

I ran into a problem while experimenting with scipy's curve_fit. I started from the example code in the docs, then changed the equation slightly, and the fit was still fine. But after I widened the np.linspace range used for x, the fitted curve came out as a straight line. Any ideas?

import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt


def f(x, a, b, c):
    # This works fine on smaller numbers
    return (a - c) * np.exp(-x / b) + c


xdata = np.linspace(60, 3060, 200)
ydata = f(xdata, 100, 400, 20)

# noise
np.random.seed(1729)
ydata = ydata + np.random.normal(size=xdata.size) * 0.2

# graph
fig, ax = plt.subplots()
plt.plot(xdata, ydata, marker="o")
pred, covar = curve_fit(f, xdata, ydata)
plt.plot(xdata, f(xdata, *pred), label="prediction")
plt.legend()  # needed for the label above to be displayed
plt.show()

Recommended answer

Here is example code using your data and equation, with the initial parameter estimates supplied by scipy's differential_evolution, an evolutionary (genetic-style) global optimizer. By default it initializes its population with Latin Hypercube sampling to ensure a thorough search of parameter space, and that requires bounds within which to search. In this example those bounds are taken from the minimum and maximum values of the data. It is much easier to supply ranges for the initial parameter estimates than specific values.

import numpy
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from scipy.optimize import differential_evolution
import warnings


def func(x, a, b, c):
    return (a - c) * numpy.exp(-x / b) + c


xData = numpy.linspace(60, 3060, 200)
yData = func(xData, 100, 400, 20)

# noise
numpy.random.seed(1729)
yData = yData + numpy.random.normal(size=xData.size) * 0.2


# function for genetic algorithm to minimize (sum of squared error)
def sumOfSquaredError(parameterTuple):
    warnings.filterwarnings("ignore") # do not print warnings by genetic algorithm
    val = func(xData, *parameterTuple)
    return numpy.sum((yData - val) ** 2.0)


def generate_Initial_Parameters():
    # min and max used for bounds
    maxX = max(xData)
    minX = min(xData)
    maxY = max(yData)
    minY = min(yData)

    parameterBounds = []
    parameterBounds.append([minY, maxY]) # search bounds for a
    parameterBounds.append([minX, maxX]) # search bounds for b
    parameterBounds.append([minY, maxY]) # search bounds for c

    # pass a fixed seed to differential_evolution for repeatable results
    result = differential_evolution(sumOfSquaredError, parameterBounds, seed=3)
    return result.x

# by default (polish=True), differential_evolution finishes by refining its best
# solution with scipy.optimize.minimize (L-BFGS-B), staying within the parameter bounds
geneticParameters = generate_Initial_Parameters()

# now call curve_fit without passing bounds from the genetic algorithm,
# just in case the best fit parameters are outside those bounds
fittedParameters, pcov = curve_fit(func, xData, yData, geneticParameters)
print('Fitted parameters:', fittedParameters)
print()

modelPredictions = func(xData, *fittedParameters)

absError = modelPredictions - yData

SE = numpy.square(absError) # squared errors
MSE = numpy.mean(SE) # mean squared errors
RMSE = numpy.sqrt(MSE) # Root Mean Squared Error, RMSE
Rsquared = 1.0 - (numpy.var(absError) / numpy.var(yData))

print()
print('RMSE:', RMSE)
print('R-squared:', Rsquared)

print()


##########################################################
# graphics output section
def ModelAndScatterPlot(graphWidth, graphHeight):
    f = plt.figure(figsize=(graphWidth/100.0, graphHeight/100.0), dpi=100)
    axes = f.add_subplot(111)

    # first the raw data as a scatter plot
    axes.plot(xData, yData,  'D')

    # create data for the fitted equation plot
    xModel = numpy.linspace(min(xData), max(xData))
    yModel = func(xModel, *fittedParameters)

    # now the model as a line plot
    axes.plot(xModel, yModel)

    axes.set_xlabel('X Data') # X axis data label
    axes.set_ylabel('Y Data') # Y axis data label

    plt.show()
    plt.close('all') # clean up after using pyplot

graphWidth = 800
graphHeight = 600
ModelAndScatterPlot(graphWidth, graphHeight)
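
As a side note on why the original fit goes flat: when no p0 is passed, curve_fit starts every parameter at 1. With b = 1 and x starting at 60, exp(-x / b) is vanishingly small, so the model is numerically just the constant c and the optimizer has almost no gradient to work with for a and b. The sketch below, which is only an illustration and not part of the answer above, shows that a rough hand-made starting guess is already enough for this data set. The names a0, b0, c0 and the heuristics behind them are assumptions chosen for this example, not a general recipe.

```python
import numpy as np
from scipy.optimize import curve_fit


def func(x, a, b, c):
    return (a - c) * np.exp(-x / b) + c


# same synthetic data as in the question
x = np.linspace(60, 3060, 200)
np.random.seed(1729)
y = func(x, 100, 400, 20) + np.random.normal(size=x.size) * 0.2

# With the default p0 = [1, 1, 1], b = 1, so the exponential term is
# roughly 9e-27 at the smallest x and even smaller everywhere else.
print("exp(-x_min / 1) =", np.exp(-x.min() / 1.0))

# Rough data-driven starting values (illustrative assumptions):
c0 = y[-1]                                  # plateau: last observed value
b0 = (x.max() - x.min()) / 5.0              # decay scale: a fraction of the x range
a0 = c0 + (y[0] - c0) * np.exp(x[0] / b0)   # invert the model at the first point

popt, pcov = curve_fit(func, x, y, p0=[a0, b0, c0])
print("Fitted parameters:", popt)
```

The differential_evolution approach in the answer is the more general option when no sensible guess can be read off the data; a hand-made p0 like this one is simply cheaper when the shape of the curve is known.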
