Demystifying fit_transform and transform in Scikit-learn: Which Method to Use for Data Preprocessing and When?
Machine learning models often require data preprocessing before being trained. Preprocessing involves transforming the raw data into a format that is more suitable for machine learning algorithms. In scikit-learn, there are two methods used for data preprocessing: fit_transform and transform.
fit_transform
fit_transform is a method that combines the fit() and transform() methods into a single step. It is commonly used on the training data: it learns any necessary parameters, such as the mean and standard deviation for scaling, from the training set and applies them to that same data in one pass. The learned parameters are then applied to the testing data using the transform() method. This ensures that the testing data is preprocessed in the same way as the training data.
Here's an example of using fit_transform to preprocess a training set of data:
from sklearn.preprocessing import StandardScaler
# Create a StandardScaler object
scaler = StandardScaler()
# Apply fit_transform to the training data (X_train is assumed to be a 2-D feature array defined earlier)
X_train_scaled = scaler.fit_transform(X_train)
In this example, we create a StandardScaler object called scaler. We then apply the fit_transform method to the training set X_train, which first learns the necessary parameters (here, the mean and standard deviation of each feature) and then scales the data with them. The result of fit_transform is stored in a new variable called X_train_scaled.
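To make the "fit plus transform in one step" idea concrete, here is a minimal sketch. The toy array is made up for illustration and is not part of the original example; mean_ and scale_ are the attributes where a fitted StandardScaler stores what it learned.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical toy training data: 4 samples, 2 features
X_train = np.array([[1.0, 10.0],
                    [2.0, 20.0],
                    [3.0, 30.0],
                    [4.0, 40.0]])

# One step: learn the parameters and scale the data
scaler_a = StandardScaler()
X_one_step = scaler_a.fit_transform(X_train)

# Two steps: fit() learns the parameters, transform() applies them
scaler_b = StandardScaler()
scaler_b.fit(X_train)
X_two_steps = scaler_b.transform(X_train)

# Both routes should give the same scaled data
same_result = np.allclose(X_one_step, X_two_steps)
print(same_result)

# The learned parameters are stored on the fitted scaler
learned_mean = scaler_a.mean_    # per-feature mean of X_train
learned_scale = scaler_a.scale_  # per-feature standard deviation of X_train
print(learned_mean, learned_scale)
```

The two-step form is mostly useful when you want to fit once and transform several datasets later; on the training set itself, fit_transform is simply the shorter spelling.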
transform
The transform method is used to apply the learned parameters to new data. This is typically done on testing data after the transformer has been fitted on the training data using fit_transform. By applying the same preprocessing steps to both the training and testing data, we ensure that the testing data is processed in the same way as the training data.
Here's an example of using transform to preprocess a testing set of data:
# Apply transform to the testing data
X_test_scaled = scaler.transform(X_test)
In this example, we apply the transform method to the testing set X_test, which scales it using the mean and standard deviation that were learned from the training set. No parameters are learned from the testing data. The result of transform is stored in a new variable called X_test_scaled.
Note: We transform the test data with the same scaler object that was fitted on the training data.
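As a rough check of what transform does with a fitted StandardScaler, the sketch below compares its output with a manual calculation from the stored training statistics. The toy arrays are invented for illustration.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical toy data, made up for illustration
X_train = np.array([[1.0], [2.0], [3.0], [4.0]])
X_test = np.array([[5.0], [6.0]])

scaler = StandardScaler()
scaler.fit_transform(X_train)  # learns mean_ and scale_ from X_train

mean_before = scaler.mean_.copy()
X_test_scaled = scaler.transform(X_test)

# transform() should match applying the training statistics by hand
manual = (X_test - scaler.mean_) / scaler.scale_
matches_manual = np.allclose(X_test_scaled, manual)
print(matches_manual)

# transform() does not update the learned parameters
params_unchanged = np.allclose(mean_before, scaler.mean_)
print(params_unchanged)
```

Because the test values are scaled with the training mean and standard deviation, the scaled test set will generally not have a mean of exactly zero, and that is expected.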
When to use fit_transform and transform
fit_transform is typically used on the training set to learn any necessary parameters and preprocess the data. The same fitted object (here, scaler) should then be used to preprocess the testing set using the transform method. This ensures that the testing data is processed in the same way as the training data and prevents data leakage, because nothing about the testing set influences the learned parameters.
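This also matters when the test data is distributed somewhat differently from the training data: it is still scaled with the training statistics, so the model sees it on the same scale it was trained on. Note that transform expects the test data to have the same number of features as the data the scaler was fitted on.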
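To illustrate the leakage point, here is a small sketch that contrasts fitting on the training set only with fitting on all the data before splitting. The arrays and the name leaky_scaler are hypothetical and exist only for this illustration.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical toy data, made up for illustration
X_train = np.array([[1.0], [2.0], [3.0], [4.0]])
X_test = np.array([[5.0], [6.0]])

# Recommended: fit on the training data only, then reuse the fitted scaler
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Leaky: fitting on train and test together lets test values
# influence the learned mean and standard deviation
leaky_scaler = StandardScaler()
leaky_scaler.fit(np.vstack([X_train, X_test]))
X_train_leaky = leaky_scaler.transform(X_train)

# The learned parameters differ, so the training features the model
# would see differ as well
print(scaler.mean_, leaky_scaler.mean_)
print(np.allclose(X_train_scaled, X_train_leaky))
```

In the leaky version the training features already carry information about the test set, which tends to make evaluation results look better than they would on truly unseen data.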
Putting it together
Here's an example of how to use fit_transform and transform on a complete dataset:
# Import the dataset
from sklearn.datasets import load_iris
iris = load_iris()
# Split the dataset into training and testing sets
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2)
# Create a StandardScaler object and apply it to the training data
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
# Apply the same scaler object to the testing data
X_test_scaled = scaler.transform(X_test)
In this example, we import the Iris dataset and split it into training and testing sets. We then create a StandardScaler object and use its fit_transform method to preprocess the training data. Finally, we use the same scaler object to preprocess the testing data using the transform method.
Conclusion
In summary, use fit_transform on the training data to learn the preprocessing parameters and apply them in one step, then use transform on the same fitted object to apply those learned parameters to your test or prediction data.