본문 바로가기

Data-science/machine learning

dtreeviz IndexError, decision tree visualization

728x90

 

random forest 모델 혹은 decision tree 모델을 이용했을 경우 모델이 어떻게 작동하는지 설명을 요구할 때가 있다.

RandomForestRegressor를 이용해서 회귀분석을 진행했다.

 

그런데 RandomForestRegressor같은 경우 여러 Decision tree의 앙상블 모델로 tree를 한 번에 시각화하기가 어렵다.

그래서 tree중 하나를 선택해서 시각화한다.

 

그러다 좋은 라이브러리를 발견했다.

dtreeviz 

이런식으로 쓰면되는 간단하고 편리한 라이브러리다!

model = Pipeline([('scaler',MinMaxScaler()), ('DecisionTreeRegressor', DecisionTreeRegressor(criterion='mae', max_depth=3, random_state=0))])

model.fit(X, y=Y)

from dtreeviz.trees import dtreeviz # remember to load the package

viz = dtreeviz(model[1], np.array(X), np.array(Y).reshape(-1),
              target_name='custom_target', feature_names=['a', 'b', 'c'])
viz.view()

여기서 model을 pipeline으로 정의해서 model[1]에  DecisionTreeRegressor가 담겨있다. 그래서  dtreeviz의 첫 인자의 model[1]을 넣었다.

그런데...!

 custom model, data로 하면 아래와 같은 에러가 계속 발생했다.

IndexError

 

dtreeviz github사이트에 들어가서 코드를 하나하나 뜯어 분석했다... 쓸데없이...

1시간 넘게 삽질했다.

해결했다. model을 pipeline으로 말고 그냥 model만으로 정의하니 되더라. 이유는 알 수 없다...

model = DecisionTreeRegressor(criterion='mae', max_depth=3, random_state=0)

model.fit(X, y=Y)

from dtreeviz.trees import dtreeviz # remember to load the package

viz = dtreeviz(model, np.array(X), np.array(Y).reshape(-1),
              target_name='custom_target', feature_names=['a', 'b', 'c'])
viz.view()

위 코드를 실행하면 새로운 브라우저가 뜨면서 아래 그림이 생긴다!

 

mljar.com/blog/visualize-decision-tree/

 

Visualize a Decision Tree in 4 Ways with Scikit-Learn and Python

A Decision Tree is a supervised algorithm used in machine learning. It is using a binary tree graph (each node has two children) to assign for each data sample a target value. The target values are presented in the tree leaves. To reach to the leaf, the sa

mljar.com

github.com/parrt/dtreeviz