6. Forests and Trees#

Ensemble methods combine the predictions of multiple models to produce better predictions than any single model would alone.

  • weak learners - simple models that predict slightly better than random guessing

  • strong learners - more complex models that predict significantly better than chance

6.1. Random Forests (RandomForestClassifier)#

A random forest builds numerous trees in parallel (independently of one another). Each tree is kept relatively shallow, sometimes only a single decision node (called a stump).

To make a prediction, a sample is run through every tree, each tree casts a vote, and the majority vote wins. But what makes the trees in the forest different from each other?

During fitting, diversity among the trees is created in two ways: random feature subsets and bagging.

6.1.1. Random feature subsets#

Each tree only gets a subset of the features. For example, in a random forest deciding whether or not you should buy a car, one tree might make a prediction based on ['reliability', 'fuel economy', 'price'], another on ['top speed', 'interior room', 'cost to repair'], another on ['resale value', 'fuel economy', 'value of standard tech'], and so on.
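In scikit-learn this is controlled by RandomForestClassifier's max_features parameter (strictly, scikit-learn draws a fresh random feature subset at each split, rather than once per tree). A minimal sketch, with illustrative values:

from sklearn.ensemble import RandomForestClassifier

# each split considers a random subset of sqrt(n_features) features,
# which is the default for classification
forest = RandomForestClassifier(n_estimators=100, max_features='sqrt', random_state=42)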

6.1.2. Bagging (Bootstrap aggregating)#

Bootstrapping is a method for creating new data sets by sampling an existing data set: samples are drawn at random, and the same sample may be drawn more than once (called sampling with replacement).

Bagging is a method that uses bootstrapping to create a different training set for each tree, and then aggregates the trees' results.
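A minimal illustration of a bootstrap sample using numpy (the data here is just a stand-in):

import numpy as np

rng = np.random.default_rng(42)
data = np.arange(10)  # stand-in for 10 training samples

# sampling with replacement: expect some duplicates and some omissions
bootstrap_sample = rng.choice(data, size=len(data), replace=True)
print(bootstrap_sample)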

6.1.3. Why does it work?#

The idea is that no single tree will be great, but each will make different mistakes, and there will be more overlap among correct predictions than among mistakes. So for any given sample, the majority vote is more likely to be correct than any one tree (a quick check of this is sketched below).
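A back-of-the-envelope check, assuming each tree is right 60% of the time and (unrealistically) that the trees' errors are fully independent:

from math import comb

# 101 trees, each correct with probability 0.6, errors assumed independent
p, n = 0.6, 101

# P(majority correct) = P(at least 51 of the 101 trees are correct)
p_majority = sum(comb(n, k) * p**k * (1 - p)**(n - k) for k in range(n // 2 + 1, n + 1))
print(f'{p_majority:.3f}')  # ~0.98: the ensemble far outperforms any single tree

Real trees trained on overlapping data are correlated, so the gain is smaller in practice, but the direction of the effect is the same.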

6.2. Gradient Boosted Trees#

Whereas random forests fit trees in parallel and give every tree an equal vote, boosted trees are built sequentially, each new tree focusing on the shortcomings of the previous ones. And at the voting stage, some trees get more say than others.

There are many flavors of boosted trees: AdaBoost, XGBoost, and CatBoost, among others.

They all work a little differently, but here’s an outline of AdaBoost as an example:

6.2.1. AdaBoost (Adaptive Boosting)#

In AdaBoost, a tree comprises only one decision node; this kind of tree is called a stump. (In scikit-learn, AdaBoost is implemented as AdaBoostClassifier; the example below uses the closely related GradientBoostingClassifier.) In each iteration, a new stump is created that splits the data based on a different condition. As the algorithm iterates, it keeps track of:

  • Sample Weight - each iteration, the algorithm focuses more on the misclassified samples.

    • A sample that is classified correctly is down-weighted: we got this one right, don't spend more energy on it.

    • A sample that is classified incorrectly is up-weighted: we got this one wrong, focus on it.

  • Tree Influence - how much say a tree will have in the final vote. Trees that classify better get more say (the formula behind this is sketched just after this list).

    • A tree that is 50% correct gets no say; it is just guessing.

    • A tree that is more than 50% correct gets a positive vote (0 to infinity). A tree that is 100% correct gets an infinite vote! Listen to that tree!

    • A tree that is less than 50% correct gets a negative vote (0 to negative infinity). A tree that is 0% correct gets a negatively infinite vote! Do the opposite of that tree!
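All three behaviors fall out of the standard AdaBoost "amount of say" formula, alpha = 0.5 * ln((1 - eps) / eps), where eps is the stump's weighted error rate. A minimal sketch:

import numpy as np

def tree_influence(weighted_error):
    # 0.5 * ln((1 - eps) / eps): zero at eps = 0.5,
    # +infinity as eps -> 0, -infinity as eps -> 1
    return 0.5 * np.log((1 - weighted_error) / weighted_error)

print(tree_influence(0.5))   # 0.0   -- coin-flip tree gets no say
print(tree_influence(0.05))  # ~1.47 -- strong tree, big positive say
print(tree_influence(0.95))  # ~-1.47 -- terrible tree, its vote is inverted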

The AdaBoost process:

  1. Start with all the samples weighted equally.

  2. As in a decision tree, pick a question that splits the data so as to minimize Gini impurity.

  3. Sum up the sample weights of the misclassified samples and calculate the tree's influence.

  4. Assign new weights to samples, increasing weights on mistakes and decreasing weights on correct classifications.

  5. Repeat steps 2-4 with the new weights (each round fits a new stump) until the classification error is below some threshold you choose.

When you predict, you feed the sample through all the stumps and each stump votes according to its influence (a toy version of the reweighting and the weighted vote is sketched below).
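Here is a toy sketch of one reweighting round and of the final weighted vote, using the classic AdaBoost update with labels coded as +/-1. This illustrates the idea only; it is not how GradientBoostingClassifier works internally:

import numpy as np

y_true = np.array([1, 1, -1, -1])    # labels coded as +/-1
y_pred = np.array([1, -1, -1, -1])   # one stump's predictions; sample 1 is wrong
w = np.full(4, 0.25)                 # start with equal sample weights

eps = w[y_pred != y_true].sum()          # weighted error = 0.25
alpha = 0.5 * np.log((1 - eps) / eps)    # this stump's influence

w = w * np.exp(-alpha * y_true * y_pred)  # mistakes up-weighted, hits down-weighted
w = w / w.sum()                           # renormalize so weights sum to 1

# final prediction for a sample: the sign of the influence-weighted sum of
# all stump votes, e.g. np.sign(np.sum(alphas * stump_votes))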

6.2.2. Example: Spam email prediction#

# !pip install ucimlrepo
from ucimlrepo import fetch_ucirepo 
import pandas as pd
  
# fetch dataset 
spambase = fetch_ucirepo(id=94) 
  
# data (as pandas dataframes) 
X = spambase.data.features 
Y = spambase.data.targets 

# to make y a compatible shape for sklearn models
y = Y['Class']
labels = ['ham', 'spam']

# metadata 
print(spambase.metadata) 
  
# variable information 
print(spambase.variables) 
{'uci_id': 94, 'name': 'Spambase', 'repository_url': 'https://archive.ics.uci.edu/dataset/94/spambase',
 'abstract': 'Classifying Email as Spam or Non-Spam', 'area': 'Computer Science', 'tasks': ['Classification'],
 'num_instances': 4601, 'num_features': 57, 'feature_types': ['Integer', 'Real'], 'target_col': ['Class'],
 'has_missing_values': 'no', 'year_of_dataset_creation': 1999, 'dataset_doi': '10.24432/C53G6X',
 'creators': ['Mark Hopkins', 'Erik Reeber', 'George Forman', 'Jaap Suermondt'], ...}
(metadata printout abridged; its additional_info notes that the non-spam emails came from work and
personal mail, so the word 'george' and the area code '650' are strong non-spam indicators; that 48
features are word frequencies, 6 are character frequencies, and 3 measure runs of capital letters;
and that typical performance is around 7% misclassification error)
                          name     role        type demographic  \
0               word_freq_make  Feature  Continuous        None   
1            word_freq_address  Feature  Continuous        None   
2                word_freq_all  Feature  Continuous        None   
3                 word_freq_3d  Feature  Continuous        None   
4                word_freq_our  Feature  Continuous        None   
5               word_freq_over  Feature  Continuous        None   
6             word_freq_remove  Feature  Continuous        None   
7           word_freq_internet  Feature  Continuous        None   
8              word_freq_order  Feature  Continuous        None   
9               word_freq_mail  Feature  Continuous        None   
10           word_freq_receive  Feature  Continuous        None   
11              word_freq_will  Feature  Continuous        None   
12            word_freq_people  Feature  Continuous        None   
13            word_freq_report  Feature  Continuous        None   
14         word_freq_addresses  Feature  Continuous        None   
15              word_freq_free  Feature  Continuous        None   
16          word_freq_business  Feature  Continuous        None   
17             word_freq_email  Feature  Continuous        None   
18               word_freq_you  Feature  Continuous        None   
19            word_freq_credit  Feature  Continuous        None   
20              word_freq_your  Feature  Continuous        None   
21              word_freq_font  Feature  Continuous        None   
22               word_freq_000  Feature  Continuous        None   
23             word_freq_money  Feature  Continuous        None   
24                word_freq_hp  Feature  Continuous        None   
25               word_freq_hpl  Feature  Continuous        None   
26            word_freq_george  Feature  Continuous        None   
27               word_freq_650  Feature  Continuous        None   
28               word_freq_lab  Feature  Continuous        None   
29              word_freq_labs  Feature  Continuous        None   
30            word_freq_telnet  Feature  Continuous        None   
31               word_freq_857  Feature  Continuous        None   
32              word_freq_data  Feature  Continuous        None   
33               word_freq_415  Feature  Continuous        None   
34                word_freq_85  Feature  Continuous        None   
35        word_freq_technology  Feature  Continuous        None   
36              word_freq_1999  Feature  Continuous        None   
37             word_freq_parts  Feature  Continuous        None   
38                word_freq_pm  Feature  Continuous        None   
39            word_freq_direct  Feature  Continuous        None   
40                word_freq_cs  Feature  Continuous        None   
41           word_freq_meeting  Feature  Continuous        None   
42          word_freq_original  Feature  Continuous        None   
43           word_freq_project  Feature  Continuous        None   
44                word_freq_re  Feature  Continuous        None   
45               word_freq_edu  Feature  Continuous        None   
46             word_freq_table  Feature  Continuous        None   
47        word_freq_conference  Feature  Continuous        None   
48                 char_freq_;  Feature  Continuous        None   
49                 char_freq_(  Feature  Continuous        None   
50                 char_freq_[  Feature  Continuous        None   
51                 char_freq_!  Feature  Continuous        None   
52                 char_freq_$  Feature  Continuous        None   
53                 char_freq_#  Feature  Continuous        None   
54  capital_run_length_average  Feature  Continuous        None   
55  capital_run_length_longest  Feature  Continuous        None   
56    capital_run_length_total  Feature  Continuous        None   
57                       Class   Target      Binary        None   

                 description units missing_values  
0-56                    None  None             no  
57  spam (1) or not spam (0)  None             no  
X.describe()
word_freq_make word_freq_address word_freq_all word_freq_3d word_freq_our word_freq_over word_freq_remove word_freq_internet word_freq_order word_freq_mail ... word_freq_conference char_freq_; char_freq_( char_freq_[ char_freq_! char_freq_$ char_freq_# capital_run_length_average capital_run_length_longest capital_run_length_total
count 4601.000000 4601.000000 4601.000000 4601.000000 4601.000000 4601.000000 4601.000000 4601.000000 4601.000000 4601.000000 ... 4601.000000 4601.000000 4601.000000 4601.000000 4601.000000 4601.000000 4601.000000 4601.000000 4601.000000 4601.000000
mean 0.104553 0.213015 0.280656 0.065425 0.312223 0.095901 0.114208 0.105295 0.090067 0.239413 ... 0.031869 0.038575 0.139030 0.016976 0.269071 0.075811 0.044238 5.191515 52.172789 283.289285
std 0.305358 1.290575 0.504143 1.395151 0.672513 0.273824 0.391441 0.401071 0.278616 0.644755 ... 0.285735 0.243471 0.270355 0.109394 0.815672 0.245882 0.429342 31.729449 194.891310 606.347851
min 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ... 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 1.000000 1.000000 1.000000
25% 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ... 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 1.588000 6.000000 35.000000
50% 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ... 0.000000 0.000000 0.065000 0.000000 0.000000 0.000000 0.000000 2.276000 15.000000 95.000000
75% 0.000000 0.000000 0.420000 0.000000 0.380000 0.000000 0.000000 0.000000 0.000000 0.160000 ... 0.000000 0.000000 0.188000 0.000000 0.315000 0.052000 0.000000 3.706000 43.000000 266.000000
max 4.540000 14.280000 5.100000 42.810000 10.000000 5.880000 7.270000 11.110000 5.260000 18.180000 ... 10.000000 4.385000 9.752000 4.081000 32.478000 6.003000 19.829000 1102.500000 9989.000000 15841.000000

8 rows × 57 columns

6.2.3. Example: Palmer Penguins#

Note: the Palmer Penguins loading code below is left commented out, so the model comparison that follows actually runs on the Spambase X and y loaded above. Uncomment the block (and its labels) to rerun everything on the penguins instead.

# palmer = pd.read_csv('https://gist.githubusercontent.com/slopp/ce3b90b9168f2f921784de84fa445651/raw/4ecf3041f0ed4913e7c230758733948bc561f434/penguins.csv', index_col = 'rowid')

# palmer.dropna(axis = 0, inplace=True)
# palmer.reset_index(drop = True, inplace=True)

# features = ['bill_length_mm', 'bill_depth_mm',
#        'flipper_length_mm', 'body_mass_g']

# target = 'species'
# labels = ['Adelie', 'Chinstrap', 'Gentoo']

# X = palmer[features]
# y = palmer[target]
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.model_selection import GridSearchCV, train_test_split


# Split data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)

# Decision Tree
dt_params = {'max_depth': [3, 6, 9, None], 'min_samples_split': [2, 5, 10]}
dt_grid = GridSearchCV(DecisionTreeClassifier(random_state=42), dt_params, cv=5, n_jobs=-1)
dt_grid.fit(X_train, y_train)

# Random Forest
rf_params = {'n_estimators': [10, 30, 100], 'max_depth': [1, 2]}
rf_grid = GridSearchCV(RandomForestClassifier(random_state=42), rf_params, cv=5, n_jobs=-1)
rf_grid.fit(X_train, y_train)

# Gradient Boosted Trees
gb_params = {'n_estimators': [10, 30, 100], 'max_depth': [1, 2]}
gb_grid = GridSearchCV(GradientBoostingClassifier(random_state=42), gb_params, cv=5, n_jobs=-1)
gb_grid.fit(X_train, y_train)


# Get the best models
print("Best Decision Tree params:", dt_grid.best_params_)
print("Best Random Forest params:", rf_grid.best_params_)
print("Best Gradient Boosted params:", gb_grid.best_params_)

tree = dt_grid.best_estimator_
forest = rf_grid.best_estimator_
boosted = gb_grid.best_estimator_
Best Decision Tree params: {'max_depth': 6, 'min_samples_split': 10}
Best Random Forest params: {'max_depth': 2, 'n_estimators': 100}
Best Gradient Boosted params: {'max_depth': 2, 'n_estimators': 100}
y_tree_train = tree.predict(X_train)
y_tree_test = tree.predict(X_test)

y_forest_train = forest.predict(X_train)
y_forest_test = forest.predict(X_test)

y_boosted_train = boosted.predict(X_train)
y_boosted_test = boosted.predict(X_test)
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report

import matplotlib.pyplot as plt

models = [
    ('Decision Tree', y_tree_train, y_tree_test),
    ('Random Forest', y_forest_train, y_forest_test),
    ('Gradient Boosted', y_boosted_train, y_boosted_test)
]

fig, axes = plt.subplots(3, 2, figsize=(8, 9))

for i, (name, y_pred_train, y_pred_test) in enumerate(models):
    cm_train = confusion_matrix(y_train, y_pred_train)
    cm_test = confusion_matrix(y_test, y_pred_test)

    disp_train = ConfusionMatrixDisplay(cm_train, display_labels=labels)
    disp_test = ConfusionMatrixDisplay(cm_test, display_labels=labels)

    disp_train.plot(ax=axes[i, 0], cmap='Blues', values_format='d')
    axes[i, 0].set_title(f'{name} - Train')

    disp_test.plot(ax=axes[i, 1], cmap='Blues', values_format='d')
    axes[i, 1].set_title(f'{name} - Test')

plt.tight_layout()
[Figure: train vs. test confusion matrices for the Decision Tree, Random Forest, and Gradient Boosted models]
import numpy as np

def display_feature_importance(model):
    imps = model.feature_importances_
    features = model.feature_names_in_
    
    sort_idx = np.argsort(imps)[::-1]
    imps = imps[sort_idx]
    features = features[sort_idx]
    
    for k, (feature, imp) in enumerate(zip(features, imps), start = 1):
        print(f'{k:>3}. {feature:_<30}{imp:.4f}')
        
print('\nDECISION TREE\n====================================')
display_feature_importance(tree)
DECISION TREE
====================================
  1. char_freq_$___________________0.4501
  2. word_freq_remove______________0.2068
  3. char_freq_!___________________0.0992
  4. word_freq_hp__________________0.0598
  5. capital_run_length_total______0.0470
  6. word_freq_free________________0.0320
  7. word_freq_edu_________________0.0168
  8. word_freq_george______________0.0163
  9. word_freq_000_________________0.0162
 10. capital_run_length_longest____0.0110
 11. word_freq_you_________________0.0108
 12. capital_run_length_average____0.0089
 13. word_freq_1999________________0.0081
 14. word_freq_hpl_________________0.0066
 15. word_freq_email_______________0.0038
 16. word_freq_over________________0.0023
 17. word_freq_conference__________0.0022
 18. word_freq_mail________________0.0022
 19. word_freq_font________________0.0000
 20. word_freq_business____________0.0000
 21. word_freq_your________________0.0000
 22. word_freq_credit______________0.0000
 23. word_freq_address_____________0.0000
 24. word_freq_report______________0.0000
 25. word_freq_addresses___________0.0000
 26. word_freq_all_________________0.0000
 27. word_freq_people______________0.0000
 28. word_freq_will________________0.0000
 29. word_freq_order_______________0.0000
 30. word_freq_internet____________0.0000
 31. word_freq_our_________________0.0000
 32. word_freq_3d__________________0.0000
 33. word_freq_receive_____________0.0000
 34. word_freq_lab_________________0.0000
 35. word_freq_money_______________0.0000
 36. word_freq_cs__________________0.0000
 37. char_freq_#___________________0.0000
 38. char_freq_[___________________0.0000
 39. char_freq_(___________________0.0000
 40. char_freq_;___________________0.0000
 41. word_freq_table_______________0.0000
 42. word_freq_re__________________0.0000
 43. word_freq_project_____________0.0000
 44. word_freq_original____________0.0000
 45. word_freq_meeting_____________0.0000
 46. word_freq_direct______________0.0000
 47. word_freq_650_________________0.0000
 48. word_freq_pm__________________0.0000
 49. word_freq_parts_______________0.0000
 50. word_freq_technology__________0.0000
 51. word_freq_85__________________0.0000
 52. word_freq_415_________________0.0000
 53. word_freq_data________________0.0000
 54. word_freq_857_________________0.0000
 55. word_freq_telnet______________0.0000
 56. word_freq_labs________________0.0000
 57. word_freq_make________________0.0000
print('\nRANDOM FOREST\n====================================')
display_feature_importance(forest)
RANDOM FOREST
====================================
  1. char_freq_$___________________0.1412
  2. char_freq_!___________________0.1224
  3. word_freq_remove______________0.1150
  4. word_freq_free________________0.0709
  5. capital_run_length_longest____0.0601
  6. word_freq_your________________0.0578
  7. capital_run_length_average____0.0521
  8. word_freq_money_______________0.0515
  9. capital_run_length_total______0.0507
 10. word_freq_george______________0.0482
 11. word_freq_000_________________0.0414
 12. word_freq_hp__________________0.0302
 13. word_freq_internet____________0.0232
 14. word_freq_hpl_________________0.0213
 15. word_freq_our_________________0.0199
 16. word_freq_you_________________0.0195
 17. word_freq_all_________________0.0141
 18. word_freq_1999________________0.0098
 19. word_freq_business____________0.0094
 20. word_freq_receive_____________0.0072
 21. word_freq_over________________0.0065
 22. word_freq_make________________0.0051
 23. word_freq_address_____________0.0045
 24. word_freq_edu_________________0.0028
 25. word_freq_will________________0.0026
 26. word_freq_order_______________0.0025
 27. word_freq_lab_________________0.0024
 28. word_freq_re__________________0.0016
 29. word_freq_addresses___________0.0016
 30. word_freq_credit______________0.0012
 31. word_freq_meeting_____________0.0011
 32. word_freq_labs________________0.0006
 33. char_freq_;___________________0.0005
 34. word_freq_415_________________0.0003
 35. char_freq_(___________________0.0002
 36. word_freq_conference__________0.0002
 37. word_freq_original____________0.0002
 38. word_freq_mail________________0.0001
 39. word_freq_pm__________________0.0001
 40. word_freq_technology__________0.0001
 41. char_freq_[___________________0.0000
 42. word_freq_3d__________________0.0000
 43. char_freq_#___________________0.0000
 44. word_freq_table_______________0.0000
 45. word_freq_85__________________0.0000
 46. word_freq_people______________0.0000
 47. word_freq_report______________0.0000
 48. word_freq_project_____________0.0000
 49. word_freq_font________________0.0000
 50. word_freq_cs__________________0.0000
 51. word_freq_direct______________0.0000
 52. word_freq_parts_______________0.0000
 53. word_freq_650_________________0.0000
 54. word_freq_telnet______________0.0000
 55. word_freq_857_________________0.0000
 56. word_freq_data________________0.0000
 57. word_freq_email_______________0.0000
print('\nBOOSTED TREE\n====================================')
display_feature_importance(boosted)
BOOSTED TREE
====================================
  1. char_freq_$___________________0.2472
  2. char_freq_!___________________0.2080
  3. word_freq_remove______________0.1500
  4. word_freq_free________________0.0722
  5. word_freq_hp__________________0.0679
  6. capital_run_length_average____0.0658
  7. capital_run_length_longest____0.0375
  8. word_freq_george______________0.0357
  9. word_freq_your________________0.0247
 10. word_freq_money_______________0.0200
 11. word_freq_our_________________0.0175
 12. capital_run_length_total______0.0100
 13. word_freq_edu_________________0.0094
 14. word_freq_650_________________0.0068
 15. word_freq_re__________________0.0041
 16. word_freq_meeting_____________0.0032
 17. word_freq_000_________________0.0032
 18. word_freq_1999________________0.0031
 19. word_freq_receive_____________0.0026
 20. word_freq_internet____________0.0020
 21. word_freq_you_________________0.0018
 22. word_freq_business____________0.0017
 23. word_freq_over________________0.0015
 24. char_freq_;___________________0.0012
 25. word_freq_3d__________________0.0008
 26. word_freq_font________________0.0008
 27. word_freq_project_____________0.0004
 28. word_freq_conference__________0.0004
 29. word_freq_report______________0.0003
 30. word_freq_will________________0.0002
 31. word_freq_addresses___________0.0000
 32. word_freq_people______________0.0000
 33. word_freq_email_______________0.0000
 34. word_freq_all_________________0.0000
 35. word_freq_address_____________0.0000
 36. word_freq_mail________________0.0000
 37. word_freq_order_______________0.0000
 38. word_freq_lab_________________0.0000
 39. word_freq_credit______________0.0000
 40. word_freq_pm__________________0.0000
 41. char_freq_#___________________0.0000
 42. char_freq_[___________________0.0000
 43. char_freq_(___________________0.0000
 44. word_freq_table_______________0.0000
 45. word_freq_original____________0.0000
 46. word_freq_cs__________________0.0000
 47. word_freq_direct______________0.0000
 48. word_freq_parts_______________0.0000
 49. word_freq_hpl_________________0.0000
 50. word_freq_technology__________0.0000
 51. word_freq_85__________________0.0000
 52. word_freq_415_________________0.0000
 53. word_freq_data________________0.0000
 54. word_freq_857_________________0.0000
 55. word_freq_telnet______________0.0000
 56. word_freq_labs________________0.0000
 57. word_freq_make________________0.0000

6.2.4. In-class exercise#

The following dataset can be found at the UCI ML repository.

Based on census information, can we predict whether an individual makes over $50K/yr?

from ucimlrepo import fetch_ucirepo 
  
# fetch dataset 
adult = fetch_ucirepo(id=2) 
  
# data (as pandas dataframes) 
X = adult.data.features 
y = adult.data.targets 

X
age workclass fnlwgt education education-num marital-status occupation relationship race sex capital-gain capital-loss hours-per-week native-country
0 39 State-gov 77516 Bachelors 13 Never-married Adm-clerical Not-in-family White Male 2174 0 40 United-States
1 50 Self-emp-not-inc 83311 Bachelors 13 Married-civ-spouse Exec-managerial Husband White Male 0 0 13 United-States
2 38 Private 215646 HS-grad 9 Divorced Handlers-cleaners Not-in-family White Male 0 0 40 United-States
3 53 Private 234721 11th 7 Married-civ-spouse Handlers-cleaners Husband Black Male 0 0 40 United-States
4 28 Private 338409 Bachelors 13 Married-civ-spouse Prof-specialty Wife Black Female 0 0 40 Cuba
... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
48837 39 Private 215419 Bachelors 13 Divorced Prof-specialty Not-in-family White Female 0 0 36 United-States
48838 64 NaN 321403 HS-grad 9 Widowed NaN Other-relative Black Male 0 0 40 United-States
48839 38 Private 374983 Bachelors 13 Married-civ-spouse Prof-specialty Husband White Male 0 0 50 United-States
48840 44 Private 83891 Bachelors 13 Divorced Adm-clerical Own-child Asian-Pac-Islander Male 5455 0 40 United-States
48841 35 Self-emp-inc 182148 Bachelors 13 Married-civ-spouse Exec-managerial Husband White Male 0 0 60 United-States

48842 rows × 14 columns

# the UCI test split writes these labels with a trailing period; normalize them
# (assigning the result instead of using inplace avoids pandas' SettingWithCopyWarning)
y = y.replace({'<=50K.': '<=50K', '>50K.': '>50K'})
X = X.drop(columns='education')  # 'education-num' already encodes this ordinally
X = X.replace({np.nan: '?'})     # treat missing values as their own '?' category
from sklearn.preprocessing import OrdinalEncoder, OneHotEncoder, StandardScaler
from sklearn.compose import ColumnTransformer

ord_features = ['sex']
oe = OrdinalEncoder(categories = [['Male', 'Female']])

cat_features = ['workclass', 'marital-status', 'occupation', 'relationship', 'race', 'native-country']
oh = OneHotEncoder()

ss = StandardScaler()

num_features = ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week']
ct = ColumnTransformer([
    ('ord', oe, ord_features),
    ('oh', oh, cat_features),
    ('ss', ss, num_features)
],
    sparse_threshold = 0,
    verbose_feature_names_out=False)

Xt = ct.fit_transform(X)
Xt.shape
(48842, 91)
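One possible continuation, mirroring the Spambase workflow above. This sketch assumes the target column is named 'income' (check y.columns), and the hyperparameters are just starting points:

from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split

y1 = y['income']  # sklearn wants a 1-D target; the column name is assumed here

Xt_train, Xt_test, y_train, y_test = train_test_split(Xt, y1, test_size=0.5, random_state=42)

gb = GradientBoostingClassifier(n_estimators=100, max_depth=2, random_state=42)
gb.fit(Xt_train, y_train)
print(gb.score(Xt_test, y_test))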