Scatter Plot

Definition

Core Statement

A Scatter Plot displays the relationship between two continuous numerical variables. One variable is plotted on the x-axis and the other on the y-axis, with each point representing an observation. It is the primary tool for detecting correlation, clusters, and outliers.


Purpose

  1. Assess Relationship: Positive? Negative? None?
  2. Check Linearity: Is the relationship linear (as Linear Regression assumes) or curved?
  3. Find Outliers: Points far from the main trend/cloud.
  4. Identify Clusters: Groups of points separated in space.
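The "Find Outliers" step is visual, but a rough numerical companion can help on larger datasets. The sketch below fits a straight line with `np.polyfit` and flags points whose residual z-score exceeds 3. The helper name `flag_outliers`, the simulated data and the 3-standard-deviation threshold are illustrative assumptions, not a standard rule. Because a large outlier also inflates the residual standard deviation, robust alternatives (for example, the median absolute deviation) may work better in practice.

```python
import numpy as np


def flag_outliers(x, y, z_thresh=3.0):
    """Hypothetical helper: mark points far from a least-squares line.

    Returns a boolean mask that is True where the residual's z-score
    exceeds z_thresh (an assumed cut-off, not a universal rule).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    slope, intercept = np.polyfit(x, y, deg=1)      # straight-line fit
    residuals = y - (slope * x + intercept)         # vertical distance to the line
    z = (residuals - residuals.mean()) / residuals.std()
    return np.abs(z) > z_thresh


# Simulated data: y = 2x + small noise, with one planted outlier at index 10
rng = np.random.default_rng(0)
x = rng.random(200)
y = 2 * x + rng.normal(0, 0.1, 200)
y[10] += 3.0

mask = flag_outliers(x, y)
print("flagged indices:", np.flatnonzero(mask))
```

Plotting the flagged points in a different color on the scatter plot is a quick way to confirm that the rule matches what the eye sees.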

Patterns to Look For

Pattern | Implication
Upward Slope | Positive Correlation (as X increases, Y tends to increase).
Downward Slope | Negative Correlation (as X increases, Y tends to decrease).
Circular Cloud | No Correlation (r ≈ 0).
U-Shape (Parabola) | Non-linear relationship. Pearson's r can be close to 0 even though the relationship is strong.
Fan Shape (Funnel) | Heteroscedasticity (the spread of Y changes as X changes).
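To make the table concrete, here is a small NumPy sketch that computes Pearson's r for four of the patterns using simulated data (the shapes, noise levels and seed are arbitrary choices). The U-shape is the instructive case: y = x² is completely determined by x, yet r comes out essentially 0 because Pearson's r only measures linear association.

```python
import numpy as np

rng = np.random.default_rng(1)
x = np.linspace(-1, 1, 201)  # symmetric around 0

# Simulated examples of the patterns in the table
patterns = {
    "upward slope": 2 * x + rng.normal(0, 0.2, x.size),
    "downward slope": -2 * x + rng.normal(0, 0.2, x.size),
    "circular cloud": rng.normal(0, 1, x.size),   # y independent of x
    "u-shape": x ** 2,                            # perfect but non-linear dependence
}

# Pearson correlation between x and each simulated y
r = {name: np.corrcoef(x, y)[0, 1] for name, y in patterns.items()}
for name, value in r.items():
    print(f"{name:15s} r = {value:+.3f}")
```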

Worked Example: Advertising vs Sales

Problem

You plot TV Ad Spend (X) vs Product Sales (Y).

  • Patterns:
    • Low Spend -> Low Sales (tight cluster).
    • High Spend -> High Sales (but very spread out).

Interpretation:

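  1. Positive Correlation: Sales tend to rise with ad spend (on its own, this does not prove the ads cause the extra sales).
  2. Possible Diminishing Returns: If the points flatten out toward the top right, the relationship is curved, so check for curvature before fitting a straight line.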
  3. Heteroscedasticity: Prediction is reliable at low spend, but risky at high spend (high variance).
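The two checks above (curvature and changing spread) can be approximated numerically. This sketch uses hypothetical simulated data, a square-root trend with noise proportional to spend; none of these numbers come from a real campaign. A negative quadratic coefficient suggests the trend flattens (diminishing returns), and a larger residual spread above the median spend than below it is consistent with the funnel shape.

```python
import numpy as np

# Hypothetical simulated data: concave trend plus noise whose
# standard deviation grows in proportion to spend.
rng = np.random.default_rng(7)
spend = rng.uniform(10, 100, 2000)
sales = 50 * np.sqrt(spend) + rng.normal(0, 0.5 * spend)

# Diminishing returns: fit sales = a*spend^2 + b*spend + c; a < 0 means the trend flattens.
a, b, c = np.polyfit(spend, sales, deg=2)

# Heteroscedasticity: compare residual spread below vs. above the median spend.
residuals = sales - np.polyval([a, b, c], spend)
low = spend < np.median(spend)
spread_low = residuals[low].std()
spread_high = residuals[~low].std()

print(f"quadratic coefficient a = {a:.4f}")
print(f"residual std: low spend {spread_low:.1f}, high spend {spread_high:.1f}")
```

Splitting at the median is a crude choice; a formal test (for example Breusch-Pagan) is the usual next step if the plot suggests a funnel.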

Assumptions


Limitations & Pitfalls

Pitfalls

  1. Overplotting: If you have 1,000,000 points, a scatter plot forms a solid blob.
    • Fix: Use Alpha blending (transparency), smaller dot size, or Hexbin plots.
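  2. Correlation ≠ Causation: A perfect line does not mean X causes Y (see Correlation vs Causation).
  3. The "Anscombe's Quartet" Trap: Four datasets can share the same correlation coefficient (r ≈ 0.816), along with nearly identical means, variances and regression lines, yet look completely different: one is roughly linear, one is clearly curved, one is a perfect line except for a single outlier, and one has constant X except for one extreme point. Always look at the plot; don't just trust the number.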
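The Anscombe trap is easy to reproduce. The values below are Anscombe's quartet as it is widely reproduced (seaborn also ships it as `sns.load_dataset("anscombe")`, which needs a network connection); treat the hard-coded copy as something to double-check against a trusted source. The sketch prints the summary numbers that are nearly identical across all four sets, which is exactly why plotting them matters.

```python
import numpy as np

# Anscombe's quartet; sets I-III share the same x values.
x_123 = np.array([10, 8, 13, 9, 11, 14, 6, 4, 12, 7, 5], dtype=float)
quartet = {
    "I": (x_123, np.array([8.04, 6.95, 7.58, 8.81, 8.33, 9.96, 7.24, 4.26, 10.84, 4.82, 5.68])),
    "II": (x_123, np.array([9.14, 8.14, 8.74, 8.77, 9.26, 8.10, 6.13, 3.10, 9.13, 7.26, 4.74])),
    "III": (x_123, np.array([7.46, 6.77, 12.74, 7.11, 7.81, 8.84, 6.08, 5.39, 8.15, 6.42, 5.73])),
    "IV": (np.array([8, 8, 8, 8, 8, 8, 8, 19, 8, 8, 8], dtype=float),
           np.array([6.58, 5.76, 7.71, 8.84, 8.47, 7.04, 5.25, 12.50, 5.56, 7.91, 6.89])),
}

stats = {}
for name, (x, y) in quartet.items():
    r = np.corrcoef(x, y)[0, 1]                      # Pearson correlation
    slope, intercept = np.polyfit(x, y, deg=1)       # least-squares line
    stats[name] = (y.mean(), r, slope, intercept)
    print(f"{name:>3}: mean(y)={y.mean():.2f}  r={r:.3f}  fit: y = {intercept:.2f} + {slope:.3f}x")
```

Scattering each pair (for example with `plt.scatter(x, y)` in a 2×2 grid) shows the four very different shapes behind the same numbers.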


Python Implementation

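import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

# Data: synthetic, roughly linear relationship (y = 2x + noise)
rng = np.random.default_rng(42)  # seeded for reproducibility
x = rng.random(100)
y = 2 * x + rng.normal(0, 0.1, 100)

# 1. Basic Scatter (alpha=0.5 makes overlapping points visible)
plt.scatter(x, y, alpha=0.5)
plt.title("Basic Scatter")
plt.xlabel("X Variable")
plt.ylabel("Y Variable")
plt.show()

# 2. Seaborn with Regression Line (least-squares fit with a 95% confidence band by default)
sns.regplot(x=x, y=y, scatter_kws={'alpha': 0.5}, line_kws={'color': 'red'})
plt.title("Scatter with Regression Line")
plt.show()

# 3. Categorical Coloring (Multivariate)
# Example DataFrame with synthetic columns; substitute your own data.
df = pd.DataFrame({
    'Age': rng.integers(20, 65, 100),
    'Gender': rng.choice(['Female', 'Male'], 100),
})
df['Income'] = 1000 * df['Age'] + rng.normal(0, 8000, 100)
sns.scatterplot(x='Age', y='Income', hue='Gender', data=df)
plt.title("Multivariate Scatter")
plt.show()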
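The Overplotting pitfall recommends alpha blending or Hexbin plots. This is a minimal sketch comparing the two side by side on a hypothetical 200,000-point synthetic dataset; `gridsize=50`, `mincnt=1`, the marker size and the colormap are arbitrary choices to tune for real data.

```python
import numpy as np
import matplotlib.pyplot as plt

# Hypothetical large synthetic dataset, big enough to overplot
rng = np.random.default_rng(3)
n = 200_000
x_big = rng.normal(0, 1, n)
y_big = 0.8 * x_big + rng.normal(0, 0.6, n)

fig, (ax_scatter, ax_hex) = plt.subplots(1, 2, figsize=(10, 4))

# Left: plain scatter with tiny, mostly transparent markers (alpha blending)
ax_scatter.scatter(x_big, y_big, s=1, alpha=0.05)
ax_scatter.set_title("Alpha-blended scatter")

# Right: hexbin aggregates points into hexagonal cells and colors each cell by its count
hb = ax_hex.hexbin(x_big, y_big, gridsize=50, cmap="viridis", mincnt=1)
ax_hex.set_title("Hexbin (count per cell)")
fig.colorbar(hb, ax=ax_hex, label="points per cell")

for ax in (ax_scatter, ax_hex):
    ax.set_xlabel("X Variable")
    ax.set_ylabel("Y Variable")

plt.show()
```

The hexbin panel keeps the density structure readable where the plain scatter saturates into a solid blob; `sns.jointplot(..., kind="hex")` is a seaborn alternative that also adds marginal histograms.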