Data-science/machine learning

dtreeviz IndexError, decision tree visualization

study&grow 2020. 11. 11. 17:05
728x90

 

random forest 모델 혹은 decision tree 모델을 이용했을 경우 모델이 어떻게 작동하는지 설명을 요구할 때가 있다.

RandomForestRegressor를 이용해서 회귀분석을 진행했다.

 

그런데 RandomForestRegressor같은 경우 여러 Decision tree의 앙상블 모델로 tree를 한 번에 시각화하기가 어렵다.

그래서 tree중 하나를 선택해서 시각화한다.

 

그러다 좋은 라이브러리를 발견했다.

dtreeviz 

이런식으로 쓰면되는 간단하고 편리한 라이브러리다!

model = Pipeline([('scaler',MinMaxScaler()), ('DecisionTreeRegressor', DecisionTreeRegressor(criterion='mae', max_depth=3, random_state=0))])

model.fit(X, y=Y)

from dtreeviz.trees import dtreeviz # remember to load the package

viz = dtreeviz(model[1], np.array(X), np.array(Y).reshape(-1),
              target_name='custom_target', feature_names=['a', 'b', 'c'])
viz.view()

여기서 model을 pipeline으로 정의해서 model[1]에  DecisionTreeRegressor가 담겨있다. 그래서  dtreeviz의 첫 인자의 model[1]을 넣었다.

그런데...!

 custom model, data로 하면 아래와 같은 에러가 계속 발생했다.

IndexError

 

dtreeviz github사이트에 들어가서 코드를 하나하나 뜯어 분석했다... 쓸데없이...

1시간 넘게 삽질했다.

해결했다. model을 pipeline으로 말고 그냥 model만으로 정의하니 되더라. 이유는 알 수 없다...

model = DecisionTreeRegressor(criterion='mae', max_depth=3, random_state=0)

model.fit(X, y=Y)

from dtreeviz.trees import dtreeviz # remember to load the package

viz = dtreeviz(model, np.array(X), np.array(Y).reshape(-1),
              target_name='custom_target', feature_names=['a', 'b', 'c'])
viz.view()

위 코드를 실행하면 새로운 브라우저가 뜨면서 아래 그림이 생긴다!

 

mljar.com/blog/visualize-decision-tree/

 

Visualize a Decision Tree in 4 Ways with Scikit-Learn and Python

A Decision Tree is a supervised algorithm used in machine learning. It is using a binary tree graph (each node has two children) to assign for each data sample a target value. The target values are presented in the tree leaves. To reach to the leaf, the sa

mljar.com

github.com/parrt/dtreeviz