はむーんのブログ

データ分析のmemoとして活用したいなと。。。ラズパイ2を触った名残も掲載。。。

ロジスティック回帰①

ロジスティック回帰をやってみる系の投稿って多いけれど、
細かいこと気にすると、難しそうだよねって思ったので、
実際に確認してみた。

以下、サンプルが完全に分かれる場合。

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from sklearn.linear_model import LogisticRegression
from sklearn.cross_validation import train_test_split



def threshold_line(constant, coef, x):
    return -(constant+coef[0]*x)/coef[1]
    
def make_date():
# init value
    constant = 30.0
    coef = [12.0,  5.0]
    npar = len(coef)
    nsample = 8000

# make data
#    x = np.random.randn(npar, nsample)
    x = (np.random.random(npar*nsample)-0.5).reshape(npar,nsample)*5
    output = 1/(1 + np.exp(-(np.dot(coef,x) + constant)))
            # + np.random.uniform(-0.48, 0.48, nsample)
    output = output<=0.5

    return  (x, output)

if __name__=="__main__":
    x,output = make_date()
    X_train, X_test, y_train, y_test = train_test_split(x.T, output, test_size=0.5)
    lr = LogisticRegression(penalty='l1')
    lr.fit(X_train, y_train)
    
    X_plot=X_test
    Y_plot=y_test
    line_x = np.array([np.min(X_plot)-1, np.max(X_plot)+1])
    line_y = threshold_line(lr.intercept_, lr.coef_[0], line_x)
    plt.plot(X_plot[Y_plot>0.5, 0], X_plot[Y_plot>0.5, 1], 'xr')
    plt.plot(X_plot[Y_plot<0.5, 0], X_plot[Y_plot<0.5, 1], 'xb')
    plt.plot(line_x, line_y, 'k--')
    print(lr.score(X_plot,Y_plot))
    print(lr.score(X_train,y_train))
    print(lr.coef_)
    print(lr.intercept_)
    print(sum(Y_plot),Y_plot.shape[0]-sum(Y_plot))

    plt.axis([-3,3,-3,3])
    plt.show()

f:id:hmnmtn:20170812223235p:plain
上手くフィット出来ている。。。

以下は、constant(intercept)=39にしてみた結果。
f:id:hmnmtn:20170812223737p:plain
赤 16, 青 3984
になり、上手くフィットできない。。。
この例だと、1:100ぐらいのサンプル数になると怪しくなる。
サンプルによっては、1:10ぐらいから、怪しくなると思われる。(サンプルを正規分布にすると、変わった。)

サンプルが大きい青の方に合う様にフィットするために、
境界線が、赤のサンプル側に移動している。

class_weight='balanced'というオプションによって、
最尤法を使う際に、weightをかけて、フィットしてくれるみたいだが、上手くいかない。。。逆に青いサンプル側にずれる。

lr = LogisticRegression(penalty='l1',class_weight='balanced')

f:id:hmnmtn:20170812225050p:plain

ペナルティは、L1, L2。L2にすると、たぶん、正負のバランスが悪すぎると、効果が小さいので、係数決定が不安定になるのではないか。
また、class_weightは、過剰に効きすぎる傾向があると分かった。。。

ベイズとか使ってやると、サンプルが無いので上手く表現してくれるのか?
とか、思ってきたので、次に、やってみる。