Come esempio importiamo i dati relativi al GDP cinese corrispondente al reddito interno annuo dal 1960 al 2014, per farne poi una regressione
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
df = pd.read_csv("china_gdp.csv")
df.head(10)
Visualizziamo i dati in un grafico
plt.figure(figsize=(8,5))
x_data, y_data = (df["Year"].values, df["Value"].values)
plt.plot(x_data, y_data, 'ro')
plt.ylabel('GDP')
plt.xlabel('Year')
plt.show()
Se ora vogliamo trovare la corretta interpolazione dobbiamo individuare il modello che meglio si adatta all'andamento dei dati sperimentali. Scegliamo la sigmoide adattabile con alcuni parametri: $\hat Y=\frac{1}{1+e^{\beta_1(X-\beta_2)}} $ Con $\beta_1$ parametro associato alla pendenza e $\beta_2$ associato alla traslazione sull'asse x
X = np.arange(-5.0, 5.0, 0.1)
Y = 1.0 / (1.0 + np.exp(-X))
plt.plot(X,Y)
plt.ylabel('Variabile Dipendente')
plt.xlabel('Variabile indipendente')
plt.show()
Costruiamo, dunque, il modello a partire dalla definizione dei parametri
def sigmoid(x, Beta_1, Beta_2):
y = 1 / (1 + np.exp(-Beta_1*(x-Beta_2)))
return y
che andrannoi adeguati empiricamente una volta sovrapposto il modello ai dati
beta_1 = 0.10
beta_2 = 1990.0
#logistic function
Y_pred = sigmoid(x_data, beta_1 , beta_2)
#plot initial prediction against datapoints
plt.plot(x_data, Y_pred*15000000000000.)
plt.plot(x_data, y_data, 'ro')
Poiché l'obiettivo è individuare i parametri, conviene per prima cosa normalizzare
# Lets normalize our data
xdata =x_data/max(x_data)
ydata =y_data/max(y_data)
La funzione curve_fit della libreria scipy permette di trovare il best fit
from scipy.optimize import curve_fit
popt, pcov = curve_fit(sigmoid, xdata, ydata)
#print the final parameters
print(" beta_1 = %f, beta_2 = %f" % (popt[0], popt[1]))
Ora possiamo confrontare il modello con i dati
x = np.linspace(1960, 2015, 55)
x = x/max(x)
plt.figure(figsize=(8,5))
y = sigmoid(x, *popt)
plt.plot(xdata, ydata, 'ro', label='data')
plt.plot(x,y, linewidth=3.0, label='fit')
plt.legend(loc='best')
plt.ylabel('GDP')
plt.xlabel('Year')
plt.show()