사이킷런의 프레임워크와 연동할 수 있는 전용 XGBoost 래퍼 클래스에는 분류용 XGBoostClassifier, 회귀용 XGBoostRegressor이 있습니다. 래퍼 클래스는 다음과 같은 장점을 가지고 있습니다.
- 사이킷런의 기본 estimator를 그대로 상속해 만들었기 때문에 fit()과 predict()만으로 학습과 예측이 가능합니다.
- GridSearchCV, Pipeline 등 다른 사이킷런의 다른 유틸리티를 그대로 함께 사용할 수 있습니다.
- 기존의 다른 프로그램의 알고리즘으로 XGBoost 래퍼 클래스를 사용할 수도 있습니다.
https://smartest-suri.tistory.com/40
이전 글에서 학습했던 기본 XGBoost API 대신 사이킷런 연동 XGBoost 래퍼 클래스 XGBoostClassifier를 사용해 모델을 학습시키고 예측을 수행해 보겠습니다.
호출 및 hyperparameter
from xgboost import XGBClassifier
xgb_wrapper = XGBClassifier(n_estimators = 400, # num_boost_round -> n_estimators
learning_rate = 0.05, # eta -> learning_rate
max_depth = 3,
eval_metric = 'logloss')
learning_rate와 같이 기존 사이킷런 하이퍼 파라미터와의 호환성 유지를 위해 변경된 하이퍼 파라미터들이 있으므로 유의합니다.
학습 및 예측
fit()과 predict() 메소드를 이용해서 모델을 학습시키고 예측을 수행해 보겠습니다.
# 학습
xgb_wrapper.fit(X_train, y_train, verbose = True)
# 예측
w_preds = xgb_wrapper.predict(X_test)
w_preds[:10]
# array([1, 0, 1, 0, 1, 1, 1, 1, 1, 0])
w_pred_proba = xgb_wrapper.predict_proba(X_test)[:, 1]
w_pred_proba[:10]
# array([8.8368094e-01, 2.7957589e-03, 8.9875424e-01, 1.8748048e-01,
# 9.9204481e-01, 9.9990714e-01, 9.9954444e-01, 9.9904817e-01,
# 9.9527210e-01, 1.9664205e-04], dtype=float32)
아주 빠르고 간단하게 학습과 예측을 수행했습니다.
평가
이전 포스팅에서 작성해 둔 get_clf_eval() 함수를 이용해서 사이킷런 래퍼 XGBoost로 만들어진 모델의 예측 성능 평가를 해 보겠습니다.
get_clf_eval(y_test, w_preds, w_pred_proba)
지표 | 이전 | 이후 |
Accuracy | 약 0.96 | 약 0.97 |
Precision | 약 0.97 | 약 0.97 |
Recall | 약 0.97 | 약 0.98 |
F1-Score | 약 0.97 | 약 0.98 |
ROC-AUC | 약 0.99 | 약 0.99 |
이전 실습보다 평가 지표가 조금 상승했습니다. 이번 실습에서는 early stopping을 따로 설정하지 않은 관계로 train 데이터를 train과 valid 데이터로 나누는 과정을 생략하였고, 그래서 트레인 데이터셋의 수가 늘어난 영향이 있을 것으로 파악됩니다. (애초에 트레이닝 데이터가 풍부한 데이터셋은 아닌 관계로)
여기까지 XGBoost 관련 두 번째 실습을 마쳐봅니다. 감사합니다 :-)
'Data Science > ML 머신러닝' 카테고리의 다른 글
ML | 캐글 Kaggle 신용카드 데이터 EDA + 모델링 실습 (0) | 2024.05.10 |
---|---|
ML | 파이썬 XGBoost API 사용하여 위스콘신 유방암 예측하기 (0) | 2024.05.08 |
EDA | 서울특별시 공중화장실 02 _ 태블로를 이용해 시각화하기 (0) | 2024.04.17 |
EDA | 서울특별시 공중화장실 01 _ pandas를 이용한 공공데이터 정제, 전처리하기 (1) | 2024.04.16 |
웹크롤링 | 연금복권720+ 당첨 데이터 분석해보기 (파이썬 requests, BeautifulSoup) (1) | 2024.04.13 |