관리 메뉴

꿈꾸는 사람.

[Pandas Groupby] 타이타닉 생존율 분석 (그룹별 상위 N개 추출 패턴) 본문

Python

[Pandas Groupby] 타이타닉 생존율 분석 (그룹별 상위 N개 추출 패턴)

현무랑 니니 2026. 1. 7. 11:13
반응형

[문제] 타이타닉 선실 등급별 상위10명의 생존율을 구하라.

  • 상황: 승객들의 생존율을 분석하고 있다. 
  • 문제: 선실 등급(Pclass)별로 나이(Age)가 가장 많은 상위 10명을 뽑은 후, 그들의 평균 생존율(Survived)을 구하시오.
  • 핵심 포인트: 그룹별로 정렬 기준(나이)과 집계 대상(생존율)이 다른 경우이다.
  • 데이터셋: 타이타닉 생존자 (캐글 csv 파일 이용)

[풀이 요약]

본 문제의 핵심은 groupby를 두 번 쓰는 것이다.

보통 groupby는 한 번만 쓰지만 '데이터 필터링(추출)과 '데이터 집계(계산)'의 단계가 다르기 때문에 groupby를 두 번 사용해야 원하는 결과를 얻을 수 있다.

전체적인 로직은 타이타닉 데이터셋을 입력 받아 먼저 필터링하고, 줄어든 데이터를 다시 계산하는 2단계로 이루어진다.

"Sort -> Groupby(Filter) -> Groupby(Agg)"

[상세 풀이]

  • 3단계 데이터 변환
1. 줄 세우기 (sort_values)
     -> 원하는 순서대로 정렬 (선실 등급별과 나이별로 정렬)

2. 잘라내기 (groupby().head(10))
    -> 그룹별 상위/하위 N개 추출
        정렬된 전체 데이터프레임(891행)을 입력 받아 선실 등급별 상위 10명씩만 남은 데이터프레임 (30행)을 반환함

3. 계산하기 (groupby().mean/sum)
    -> 이전에 만든 30줄의 데이터프레임을 각 그룹별 생존율 평균을 계산하여 선시리즈(Series) (3행)을 반환함
  • Python 코드
import pandas as pd

df = pd.read_csv('titanic.csv', encoding='utf-8')

ans = df.sort_values(by=['Pclass', 'Age'], ascending=False) \
        .groupby(by='Pclass').head(10) \
        .groupby(by='Pclass')['Survived'].mean()
        
print(ans)

 

  • 데이터의 변환 과정(Flow)

  • Step 1 (정렬): 전체 데이터를 Pclass와 Age 기준으로 정렬한 상태이다.
# 1등급 -> 2등급 -> 3등급 순으로, 그리고 나이가 많은 순으로 정렬
step1 = df.sort_values(by=['Pclass', 'Age'], ascending=False)
  • 동작: 전체 데이터(891행)가 pclass와 age 기준으로 재배열된다.
  • Step 2 (추출): .head(10)으로 상위 데이터만 잘라낸다. 이때 결과물은 다시 평범한 표(DataFrame)가 된다. (여기가 핵심 포인트!)
# 각 등급별로 10행씩만 가져오기 (총 30행이 됨)
step2 = step1.groupby(by='Pclass').head(10)
  • 핵심 원리: 여기서 groupby는 평균을 구하기 위함이 아니라, head(10)을 그룹별로 적용하기 위한 용도이다.
  • 수행 결과:
    • 입력: 891행 (전체 승객)
    • 출력: 30행 (1등급 10명 + 2등급 10명 + 3등급 10명)
    • 중요: head()를 거치면서 그룹핑 상태는 해제되고, 다시 단순한 DataFrame(표) 상태가 된다. 즉, 컴퓨터는 이제 이 30명이 1등급인지 3등급인지 그룹으로 묶어두지 않은 상태이다.
  • Step 3 (집계): 평범한 표가 되었으므로, 평균을 구하기 위해 다시 한번 .groupby로 묶어줘야 결과(Series)가 나온다.
# 30명의 데이터를 다시 등급별로 묶어서 평균 계산
step3 = step2.groupby(by='Pclass')['Survived'].mean()
  • 핵심 원리: 이전의 풀려버린 그룹을 다시 묶어주는 역할이다.
  • 수행 결과:
    • 입력: 30행의 데이터프레임
    • 출력: 등급별 생존율이 담긴 Series (3행)
  • 최종 수행 결과 확인
  • 코드를 실행하면 다음과 같이 등급별 노령 승객(Top 10)의 생존율이 계산된다.
Pclass
1    0.7
2    0.3
3    0.3
Name: Survived, dtype: float64

[결론]

이 패턴을 암기하자!

  1. sort_values: 전체 줄 세우기
  2. groupby().head(N): (1차 그룹핑) 개수 자르기 -> DataFrame 반환
  3. groupby().mean(): (2차 그룹핑) 통계 내기 -> Series 반환

이 3단계 흐름, 특히 "자르고 나면 그룹이 풀리니까 다시 묶는다"는 점만 기억하면 어떤 복잡한 문제도 해결할 수 있다.

반응형
Comments