A decision tree is a predictive model that, as its name implies, can be viewed as a tree. Decision trees are commonly used in operations research, specifically in decision analysis, to help identify the strategy most likely to reach a goal. They are also commonly used in machine learning.
In this blog we will discuss:
- How to create a decision tree for the sales of child car seats, and how to validate and prune the decision tree using the complexity parameter.
- How to use RStudio to plot the tree.
To create a decision tree in R, we can use functions such as rpart() from the rpart package, tree() from the tree package, or ctree() from the party package.
The rpart() command uses the same formula interface as the aggregate function. The dependent attribute goes on the left of the ~ sign, and the independent attributes used in the analysis go on the right.
Kyphosis: dependent attribute, as Kyphosis depends on the factors Age, Number, and Start.
Age, Number, and Start are independent attributes.
If you wanted to predict a continuous variable, such as age, you could use method="anova". This would generate decimal quantities for you. But here we just want one of two classes (for Kyphosis, absent or present), so method="class" is appropriate.
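The call described above can be sketched as follows. This is a minimal example that assumes the kyphosis data set bundled with the rpart package; the variable name `fit` is our own choice.

```r
# Minimal sketch: fit a classification tree on rpart's bundled kyphosis data.
library(rpart)

# Kyphosis (absent/present) is the dependent attribute;
# Age, Number and Start are the independent attributes.
fit <- rpart(Kyphosis ~ Age + Number + Start,
             data = kyphosis,
             method = "class")   # "class" for a categorical outcome

print(fit)   # text listing of the splits that were found
```

Switching to method = "anova" would be the analogous call for a continuous outcome.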
Plot Decision Tree
Let’s examine the tree. There are many ways to do this; the built-in approach uses the plot and text methods that ship with rpart.
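A minimal sketch of the built-in plot, assuming the same kyphosis model as before. The arguments uniform, margin, use.n, and cex are only layout choices made for this illustration.

```r
# Minimal sketch: draw the tree with rpart's own plot and text methods.
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start,
             data = kyphosis, method = "class")

plot(fit, uniform = TRUE, margin = 0.1)   # draw the branches
text(fit, use.n = TRUE, cex = 0.8)        # label splits and leaf counts
```

The result is functional but plain, which is why a nicer rendering is worth trying.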
Use the rattle package to plot the tree:
The rpart.plot and RColorBrewer packages help us to create a beautiful plot.
- rpart.plot() plots rpart models. It extends plot.rpart and text.rpart in the rpart package.
- RColorBrewer provides color palettes for the plots.
Let’s try rendering this tree a bit nicer with fancyRpartPlot() from the rattle package.
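A minimal sketch of the nicer rendering. It assumes the rattle, rpart.plot, and RColorBrewer packages are already installed. The requireNamespace() guard is our own addition so that the snippet does not fail outright when rattle is missing.

```r
# Minimal sketch: render the tree with rattle's fancyRpartPlot().
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start,
             data = kyphosis, method = "class")

# Assumption: rattle (with rpart.plot and RColorBrewer) is installed.
if (requireNamespace("rattle", quietly = TRUE)) {
  rattle::fancyRpartPlot(fit)
} else {
  message("Install rattle, rpart.plot and RColorBrewer to draw this plot.")
}
```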
Okay, now we’ve got something readable. The decisions the tree has found go a lot deeper than what we could easily spot by looking for them manually. This was a simple and efficient way to create a decision tree in R. But are you sure that this is the optimal decision tree for this data? If not, the following validation checks will help you.
Validation of Decision Tree using the ‘Complexity Parameter’ and cross validated error:
To validate the model we use the printcp() and plotcp() functions. ‘CP’ stands for the complexity parameter of the tree. printcp() displays the table of candidate prunings, one row per cp value, together with the cross-validated error (xerror) that tells us how the tree performs at each size.
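A minimal sketch of the table, again assuming the kyphosis model. The seed value is arbitrary; it is set only because rpart assigns cross-validation folds at random, so xerror changes from run to run.

```r
# Minimal sketch: show the complexity-parameter table of a fitted tree.
library(rpart)

set.seed(42)   # arbitrary seed; makes the cross-validated errors repeatable
fit <- rpart(Kyphosis ~ Age + Number + Start,
             data = kyphosis, method = "class")

printcp(fit)   # columns: CP, nsplit, rel error, xerror, xstd
```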
From the list of cp values above, we can select the one with the lowest cross-validated error and use it to prune the tree. Keep in mind that the goal is the cp whose cross-validated error (xerror) is lowest, not the smallest cp itself.
This value can be read from the cptable component of the fitted model, by taking the CP entry in the row where xerror is at its minimum.
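A minimal sketch of that lookup, assuming the kyphosis model. The name `best_cp` and the seed are our own choices.

```r
# Minimal sketch: pick the cp value with the lowest cross-validated error.
library(rpart)

set.seed(42)   # arbitrary seed; cross-validation folds are random
fit <- rpart(Kyphosis ~ Age + Number + Start,
             data = kyphosis, method = "class")

# Row of the cp table with the smallest xerror, then its CP column.
best_row <- which.min(fit$cptable[, "xerror"])
best_cp  <- fit$cptable[best_row, "CP"]

print(best_cp)
```

If several rows tie for the minimum, which.min() returns the first one, which is the row with the fewest splits.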
Let us use the plotcp() function. plotcp() provides a graphical representation of the cross-validated error summary. The cross-validated error, with its standard deviation, is plotted against the geometric means of the cp intervals, which makes it easy to see where the error reaches its minimum. Let us see what plotcp() produces.
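A minimal sketch of the plot, assuming the same kyphosis model and an arbitrary seed.

```r
# Minimal sketch: plot cross-validated error against cp.
library(rpart)

set.seed(42)   # arbitrary seed; cross-validation folds are random
fit <- rpart(Kyphosis ~ Age + Number + Start,
             data = kyphosis, method = "class")

plotcp(fit)    # x axis: cp, top axis: size of tree, y axis: relative error
```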
Prune the tree to create an optimal Decision Tree:
Thus we create a pruned decision tree.
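A minimal sketch of the pruning step, assuming the kyphosis model. The names `best_cp` and `pruned` are our own. On a data set this small the lowest cross-validated error may sit at the root, in which case the pruned tree has no splits at all, so the result is printed here rather than plotted.

```r
# Minimal sketch: prune the tree at the cp with the lowest xerror.
library(rpart)

set.seed(42)   # arbitrary seed; cross-validation folds are random
fit <- rpart(Kyphosis ~ Age + Number + Start,
             data = kyphosis, method = "class")

best_cp <- fit$cptable[which.min(fit$cptable[, "xerror"]), "CP"]
pruned  <- prune(fit, cp = best_cp)   # drop splits not worth keeping at this cp

print(pruned)
```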