This is a simple example of Linear Discriminant Analysis (LDA) using Python and the scikit-learn library.
Linear Discriminant Analysis (LDA) is a supervised dimensionality reduction and classification technique that finds the linear combinations of features that best separate two or more classes. It looks for projections that maximize the distance between the class means relative to the spread (variance) within each class. LDA is commonly used in classification tasks and works best when the classes are roughly Gaussian with similar covariance, so that a linear boundary can separate them.
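To make the "between-class distance versus within-class spread" idea concrete, here is a minimal NumPy sketch of the two-class case (Fisher's criterion). The helper names fisher_direction and fisher_ratio and the toy data are illustrative assumptions, not part of scikit-learn, and this is not how scikit-learn computes LDA internally.

```python
import numpy as np

def fisher_direction(X, y):
    """Unit vector maximizing between-class over within-class scatter (two classes)."""
    X0, X1 = X[y == 0], X[y == 1]
    m0, m1 = X0.mean(axis=0), X1.mean(axis=0)
    # Within-class scatter: sum of the scatter matrices of the two classes
    Sw = (X0 - m0).T @ (X0 - m0) + (X1 - m1).T @ (X1 - m1)
    # The optimal direction is proportional to Sw^{-1} (m1 - m0)
    w = np.linalg.solve(Sw, m1 - m0)
    return w / np.linalg.norm(w)

def fisher_ratio(X, y, w):
    """Squared distance between projected class means over projected within-class scatter."""
    z = X @ w
    z0, z1 = z[y == 0], z[y == 1]
    between = (z1.mean() - z0.mean()) ** 2
    within = ((z0 - z0.mean()) ** 2).sum() + ((z1 - z1.mean()) ** 2).sum()
    return between / within

# Hypothetical toy data: two correlated Gaussian classes with shifted means
rng = np.random.default_rng(0)
cov = [[1.0, 0.8], [0.8, 1.0]]
X = np.vstack([
    rng.multivariate_normal([0.0, 0.0], cov, 100),
    rng.multivariate_normal([2.0, 0.0], cov, 100),
])
y = np.repeat([0, 1], 100)

w = fisher_direction(X, y)
print("Fisher direction:", w)
print("Ratio along Fisher direction:", fisher_ratio(X, y, w))
print("Ratio along the first feature axis:", fisher_ratio(X, y, np.array([1.0, 0.0])))
```

The ratio along the Fisher direction is never smaller than the ratio along any other direction, which is the sense in which LDA "best separates" the classes with a single linear projection.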
Key concepts of LDA:
LDA can classify on its own, but it is also often used as a preprocessing step: the data is projected onto the discriminant components, and a second classifier is trained on the projected features.
Python Source Code:
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# Generate synthetic data with two classes
X, y = make_classification(n_samples=200, n_features=2, n_informative=2, n_redundant=0,
                           n_clusters_per_class=1, random_state=42)
# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Apply Linear Discriminant Analysis (LDA) for dimensionality reduction
# (with two classes, LDA yields at most n_classes - 1 = 1 component)
lda = LinearDiscriminantAnalysis(n_components=1)
X_train_lda = lda.fit_transform(X_train, y_train)
X_test_lda = lda.transform(X_test)
# Train a classifier (e.g., Logistic Regression) on the reduced features
from sklearn.linear_model import LogisticRegression
classifier = LogisticRegression(random_state=42)
classifier.fit(X_train_lda, y_train)
# Predict on the test set
y_pred = classifier.predict(X_test_lda)
# Calculate accuracy
accuracy = accuracy_score(y_test, y_pred)
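# Plot the projected test points and the classifier's decision boundary
# In 1D, logistic regression switches class where coef * x + intercept = 0
boundary = -classifier.intercept_[0] / classifier.coef_[0, 0]
plt.figure(figsize=(10, 6))
plt.scatter(X_test_lda[:, 0], np.zeros(len(X_test_lda)), c=y_test, cmap='viridis', marker='o', edgecolors='k')
plt.title('Linear Discriminant Analysis (LDA) Decision Boundary')
plt.xlabel('LDA Component')
plt.ylabel('Dummy Axis')
# The boundary is a single threshold on the LDA axis, so draw it as a vertical line
plt.axvline(x=boundary, color='black', linestyle='--', linewidth=2, label='Decision Boundary')
plt.legend()
plt.show()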
# Display accuracy
print(f'Accuracy: {accuracy:.2f}')
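Since LDA is itself a classifier, it can be useful to compare it directly against the two-step approach above. The following is a rough sketch that regenerates the same synthetic data; the variable names lda_clf and pipeline are illustrative, and wrapping the two steps in make_pipeline is a convenience chosen here, not something the original script does.

```python
from sklearn.datasets import make_classification
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline

# Same synthetic data and split as in the main example
X, y = make_classification(n_samples=200, n_features=2, n_informative=2, n_redundant=0,
                           n_clusters_per_class=1, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Option 1: use LDA directly as a classifier
lda_clf = LinearDiscriminantAnalysis()
lda_clf.fit(X_train, y_train)
lda_accuracy = lda_clf.score(X_test, y_test)

# Option 2: LDA projection followed by logistic regression, chained in one estimator
pipeline = make_pipeline(LinearDiscriminantAnalysis(n_components=1),
                         LogisticRegression(random_state=42))
pipeline.fit(X_train, y_train)
pipeline_accuracy = pipeline.score(X_test, y_test)

print(f'LDA as classifier: {lda_accuracy:.2f}')
print(f'LDA + Logistic Regression: {pipeline_accuracy:.2f}')
```

Chaining the steps in a pipeline keeps the projection and the classifier fitted on the same training data, so the test set is transformed consistently without separate fit_transform and transform calls.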
Explanation: