Published on
👁️

FastAPI 마이그레이션 - 2. FastAPI 앱 구현 및 배포

Authors
  • avatar
    Name
    River
    Twitter

이전 페이지로 이동 (1. 프로젝트 기반 구축)

FastAPI 마이그레이션
(Spring Boot에서 FastAPI로 전환)

2. FastAPI 앱 구현 및 배포
2.1. 최소 FastAPI 앱 생성

app/main.py

from fastapi import FastAPI

app = FastAPI(
    title="Holiday Keeper API",
    description="공휴일 정보 관리 API",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

@app.get("/")
async def root():
    return {"message": "Holiday Keeper API is running!"}

@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "service": "holiday-keeper-api",
        "version": "1.0.0"
    }

@app.get("/test")
async def test():
    return {
        "message": "테스트 성공",
        "data": ["공휴일1", "공휴일2", "공휴일3"]
    }
2.2. 환경설정 및 기본 구조

Spring Boot (자동 구성)

  • application.yml

    server:
    port: 8090
    spring:
    datasource:
        url: jdbc:postgresql://localhost:5432/holiday_db
    
    • 설정 값 저장 (DB 연결 정보, 포트 등)
    • Spring Boot의 핵심 중 하나인 Auto Configuration으로 자동 인식


  • WebConfig.java

    @Configuration
    public class WebConfig implements WebMvcConfigurer {
        
        @Override
        public void addCorsMappings(CorsRegistry registry) {
            registry.addMapping("/**")
                .allowedOrigins("http://localhost:3000")
                .allowedMethods("*");
        }
        
        @Override
        public void addInterceptors(InterceptorRegistry registry) {
            registry.addInterceptor(new LoggingInterceptor());
        }
    }
    
    • 설정 값 적용, CORS 직접 설정, 인터셉터 등록
    • @Configuration으로 자동 스캔


  • LoggingInterceptor.java

    @Component
    public class LoggingInterceptor implements HandlerInterceptor {
        @Override
        public boolean preHandle(HttpServletRequest request, 
                            HttpServletResponse response, Object handler) {
            log.info("Request: {} {}", request.getMethod(), request.getRequestURI());
            return true;
        }
    }
    
    • 실제 로직 구현



FastAPI (수동 제어)

  • .env

    SERVER_HOST="0.0.0.0"
    SERVER_PORT=8090
    DATABASE_URL="postgresql://localhost:5432/holiday_db"
    CORS_ORIGINS="http://localhost:3000,http://localhost:5173"
    
    • 설정 값 저장 (DB 연결 정보, 포트 등)


  • config.py

    from pydantic_settings import BaseSettings
    
    class Settings(BaseSettings):
        server_host: str = "0.0.0.0"
        server_port: int = 8090
        database_url: str = "postgresql://..."
        cors_origins: str = "http://localhost:3000"
        
        model_config = {"env_file": ".env"}
        
        @property
        def cors_origins_list(self) -> List[str]:
            return self.cors_origins.split(",")
    
    settings = Settings()
    
    • .env 파일 읽기
    • 설정 객체 생성


  • main.py

    from fastapi import FastAPI
    from app.core.config import settings
    from app.core.middleware import add_middleware
    
    app = FastAPI(title="Holiday Keeper API")
    
    add_middleware(app)
    
    @app.get("/")
    async def root():
        return {"message": "Hello"}
    
    • 수동으로 모든 설정 연결


  • middleware.py

    from fastapi.middleware.cors import CORSMiddleware
    from app.core.config import settings
    
    def add_middleware(app: FastAPI):
        # CORS 설정 (WebConfig 역할)
        app.add_middleware(
            CORSMiddleware,
            allow_origins=settings.cors_origins_list,
            allow_credentials=True,
            allow_methods=["*"]
        )
        
        # 로깅 미들웨어 (Interceptor 역할)
        @app.middleware("http")
        async def log_requests(request, call_next):
            start_time = time.time()
            response = await call_next(request)
            process_time = time.time() - start_time
            logger.info(f"{request.method} {request.url} - {process_time:.3f}s")
            return response
    
    • Spring Boot의 WebConfig + Interceptor 구현을 한 파일에 합친 형태


  • deps.py

    from fastapi import Depends
    from app.core.config import settings
    
    async def get_db():
        # DB 세션 의존성 주입용
        pass
    
    def get_current_user():
        # 인증 의존성 주입용  
        pass
    
    • Spring Boot의 @Autowired 의존성 주입을 수동으로 구현하는 곳

app/core/config.py

from functools import lru_cache
from typing import List
import os
from pathlib import Path

from pydantic import Field
from pydantic_settings import BaseSettings

environment = os.getenv("ENVIRONMENT", "dev")
env_file = f".env.{environment}"

_secrets_dir = Path("/run/secrets") if Path("/run/secrets").exists() else Path(os.getenv("SECRETS_PATH", "./secrets"))

class Settings(BaseSettings):
    project_name: str = Field(default="Holiday-keeper-fastapi")
    version: str = Field(default="1.0.0")
    environment: str = environment

    server_host: str = Field(default="0.0.0.0")
    server_port: int = Field(default=8090)
    cors_origins: str = Field(default="http://localhost:3000,http://localhost:5173")
    log_level: str = Field(default="INFO")

    redis_url: str = Field(default="redis://localhost:6379")
    nager_api_base_url: str = Field(default="https://date.nager.at/api/v3")

    database_url: str
    sync_database_url: str

    model_config = {
        "env_file": env_file,
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore",
        "secrets_dir": _secrets_dir,
        # "validate_default": True,
        # "use_enum_values": True
    }

    @property
    def cors_origins_list(self) -> List[str]:
        return [origin.strip() for origin in self.cors_origins.split(",")]

    @property
    def is_dev(self) -> bool:
        return self.environment.lower() == "dev"

@lru_cache()
def get_settings() -> Settings:
    return Settings()

settings = get_settings()



@lru_cache와 싱글톤 패턴

@lru_cache()
def get_settings():
    return Settings()

settings = get_settings()
  • Spring의 경우 기본적으로 싱글톤이지만 FastAPI는 그렇지 않다. 객체를 Spring의 Bean 처럼 활용하기 위함이다


  • get_settings( )

    • 단순한 팩토리 함수 (객체 생성 반환)


  • @lru_cache

    • 함수 결과 캐싱싱글톤 구현
    • 첫번째 호출 시에만 Settings() 객체 생성
    • 함수를 여러 번 호출해도 캐싱한 결과 반환
    • 여기까지가 Spring Boot의 싱글톤 빈을 수동으로 구현한 것이다.


  • settings = get_settings( )

    • 전역 객체 생성
      • 모듈 로드 시 한 번만 실행
    • 다른 파일에서 from app.core.config import settings 시 같은 객체를 사용한다.
    • 이렇게 편의 객체를 만들지 않으면 매번 다른 파일에서 함수를 호출해야 한다



Pydantic의 Field

from pydantic import Field

class Settings(BaseSettings):

database_url: str

    server_port: int = Field(default=8090, ge=1024, le=65535, description="서버 포트")
  • Field 함수를 사용하지 않는 경우

    • 예시 : database_url: str
    • 단순히 타입 힌트를 제공하며, 필드가 필수값임을 선언하는 것
    • BaseSettings는 이 필드 이름과 일치하는 것을 찾아 로딩한다.
      1. 환경 변수 (OS 환경변수)
      2. .env 파일 값
      3. model_config에 설정된 secrets_dir 내의 동일한 파일 이름
        • 예시 : database_url.txt

  • Field 함수를 사용하는 경우

    • 단순히 타입을 선언하는 것이 아닌 기본값을 지정하거나 유효성 검사 등 설정 가능

    • Spring Boot의 @Value + Bean Validation을 합친 기능

    • 필드의 이름과 환경 변수의 이름이 같아야 동작한다
    • default

      • 기본값 설정
      • default 우선 순위
        1. 환경변수 (OS 환경변수)
        2. .env 파일 값
        3. Field(default=값)
    • ge/le

      • 범위 제한 (greater equal, less equal)

    • description

      • API 문서에 표시될 설명

    • alias

      # .env  
      MY_CUSTOM_PORT=8080
      
      # Python
      server_port: int = Field(default=8090, alias="MY_CUSTOM_PORT")
      
      • 다른 이름으로 매핑해야 하는 경우 사용



BaseSettings

from pydantic_settings import BaseSettings

environment = os.getenv("ENVIRONMENT", "dev")
env_file = f".env.{environment}"

_secrets_dir = Path("/run/secrets") if Path("/run/secrets").exists() else Path(os.getenv("SECRETS_PATH", "./secrets"))

class Settings(BaseSettings):
    name: str
    version: str = Field(default="1.0.0")
    
        model_config = {
            "env_file": env_file,            # 읽을 파일 경로
            "env_file_encoding": "utf-8",    # 파일 인코딩
            "case_sensitive": False,         # 대소문자 구분 안 함
            "extra": "ignore",               # 추가 필드 처리 방식
            "validate_default": True,        # 기본값도 검증할지
            "use_enum_values": True,         # Enum 값 사용
            "secrets_dir": _secrets_dir,     # Secret 폴더 경로 지정
        }
  • BaseSettings

    • Spring Boot의 @ConfigurationProperties
    • .env 파일 자동 읽기
    • 환경변수 자동 매핑
    • 타입 변환 (문자열 → 정수, 불린 등)
    • 유효성 검증


  • 프로파일 동적 설정

    environment = os.getenv("ENVIRONMENT", "dev")
    env_file = f".env.{environment}"
    
    class Settings(BaseSettings):
        model_config = {
                    "env_file": env_file,
        }
    
    • 환경 설정 파일
      • .env.dev
      • .env.prod
      • .env.test
    • 환경 변수 설정에 따라 동적으로 사용하는 설정이다.
    • 시크릿 설정은 시크릿 파일로 관리한다.


  • _secrets_dir 전역 변수

    • Docker Secrets 환경의 경로/run/secrets를 먼저 확인하고 없다면, 로컬 환경을 확인한다.
    • 전역 변수로 설정하는 이유는 Settings 클래스 정의 시 그 값을 활용하기 위함이다


  • model_config = { }

    • BaseSettings가 어떻게 환경변수를 읽을지 설정하는 메타데이터 설정

    • Settings 클래스가 정의될 때, model_config 내부의 값들도 함께 정해진다

    • Settings 클래스의 인스턴스가 생기기 전에는 클래스 내부 필드의 값을 참조할 수 없다
      • case_sensitive

        • Spring Boot는 kebab-case를 camelCase로 자동 변환
        • 환경변수 이름의 대소문자를 구분하는지 여부

      • extra: "ignore"

        • Spring BootignoreUnknownFields = true
        • Settings 클래스에 없는 필드에 대한 처리 여부
        • extra: "ignore"
          • Settings 클래스에 없는 환경변수 무시
          • extra: "allow"일 때만 settings.UNKNOWN_FIELD로 접근 가능

      • validate_default

        • 기본값이 잘못된 경우를 검증하는지 여부

      • use_enum_values

        • Settings 클래스에서 다른 enum 클래스를 필드 타입으로 사용할지 여부
          class Settings(BaseSettings):
                  environment: Environment = Environment.DEV
          
          class Environment(Enum):
              DEV = "development"
              PROD = "production"
          
        • use_enum_values=True"development" 값 사용

        • use_enum_values=FalseEnvironment.DEV 객체 사용


      • secrets_dir

        • Secret 폴더의 경로를 지정하여 해당 경로의 값들도 읽을 수 있게 설정한다


  • Pydantic v1 (구버전)

    class Settings(BaseSettings):
        name: str
        
        class Config:               # 내부 클래스
            env_file = ".env"
            case_sensitive = True
    



@property

class Settings:
    cors_origins: str = "http://localhost:3000,http://localhost:5173"
    
    @property
    def cors_origins_list(self) -> List[str]:
        return self.cors_origins.split(",")
        

# 사용법
settings.cors_origins_list
  • 함수를 속성처럼 접근 가능하게 한다
  • Java의 Getter와 같지만 호출 시 괄호가 필요 없다
  • @{property이름}.setter
    class Settings:
        def __init__(self):
            self._cors_data = "http://localhost:3000"
        
        @property
        def cors_origins(self):  # property 이름
            return self._cors_data
        
        @cors_origins.setter
        def cors_origins(self, value):
            self._cors_data = value
    
    • 이렇게 setter를 지정할 수 있다.
    • 단, 메서드 이름도 property와 동일해야 한다.
    • 환경변수 설정이므로 Settings 클래스에서는 쓸 일이 없다.



MySQL 연결 설정

# secrets/database_url
mysql+aiomysql://root:1234@127.0.0.1:3306/hk_db?charset=utf8mb4

# secrets/sync_database_url
mysql+pymysql://root:1234@127.0.0.1:3306/hk_db?charset=utf8mb4

# .env.dev
DATABASE_URL="mysql+aiomysql://root:1234@127.0.0.1:3306/hk_db?charset=utf8mb4"
SYNC_DATABASE_URL="mysql+pymysql://root:1234@127.0.0.1:3306/hk_db?charset=utf8mb4"


# config.py
class Settings:
    database_url: str
    sync_database_url: str
  • 개발 / 운영 DB 설정

    • 개발 환경
      • 변수를 읽어들이는 순서에 따라 먼저 .env.dev에사 읽어들인다.
      • .env.dev에는 값이 설정되어 있으므로 해당 값을 읽는다.
    • 운영 환경
      • 변수를 읽어들이는 순서에 따라 먼저 .env.prod에사 읽어들인다.
      • .env.prod에는 값이 설정되어 있지 않다.
      • 따라서, 다음 순서인 시크릿 파일을 읽어서 값을 읽는다.


  • FastAPI DB 연결 설정

    • diriver, username, password를 url 경로에 포함시켜 작성
    • charset=utf8mb4
      • 한글 처리 및 이모지까지 포함하는 설정
    • autocommit=true
      • connection 레벨에서, 즉 MySQL 드랑버 레벨에서 autocommit 처리
      • FastAPI는 SQLAlchemy가 트랜잭션 관리하므로 불필요
    • use_unicode=1
      • MySQL 드라이버가 유니코드 문자열을 제대로 처리하도록 하는 설정
      • 최신 Python MySQL 드라이버에서는 기본값
    • rewriteBatchedStatements=true
      • SQLAlchemy가 자체적으로 배치 처리
    • serverTimezone=Asia/Seoul
      • 대부분 자동 처리하지만 명시적으로 설정 가능


  • 기존 Spring Boot 설정

    spring:
    datasource:
        url: jdbc:mysql://127.0.0.1:3306/hk_db?characterEncoding=utf8&autoReconnect=true&serverTimezone=Asia/Seoul&rewriteBatchedStatements=true
        username: root
        password: 1234
        driver-class-name: com.mysql.cj.jdbc.Driver
    

app/core/middleware.py

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
import time
import logging
from app.core.config import settings

logger = logging.getLogger(__name__)

def add_middleware(app: FastAPI):
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.cors_origins_list,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"]
    )

    if not settings.is_dev:
        app.add_middleware(
            TrustedHostMiddleware,
            allowed_hosts=["yourdomain.com", "*.yourdomain.com"]
        )

    if settings.is_dev:
        @app.middleware("http")
        async def log_requests(request: Request, call_next):
            start_time = time.time()
            response = await call_next(request)
            process_time = time.time() - start_time

            logger.info(
                f"{request.method} {request.url.path} - "
                f"Status: {response.status_code} - "
                f"Time: {process_time:.4f}s"
            )
            return response



중요 설정

  • app.add_middleware( )

    from fastapi import FastAPI, Request
    
    def add_middleware(app: FastAPI):
    
        app.add_middleware(
            TrustedHostMiddleware,
            allowed_hosts=["yourdomain.com", "*.yourdomain.com"]
        )
    
    • add_middleware()
      • FastAPI에 내장된 함수


  • 미들웨어 개념

    • Spring Boot의 Filter/Interceptor
      @Component
      public class LoggingFilter implements Filter {
          @Override
          public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) {
              // 요청 전 처리
              chain.doFilter(request, response);  // 다음 필터/컨트롤러로
              // 응답 후 처리
          }
      }
      

    • FastAPI의 미들웨어
      @app.middleware("http")
      async def logging_middleware(request, call_next):
          # 요청 전 처리
          response = await call_next(request)  # 다음 미들웨어/라우터로
          # 응답 후 처리
          return response
      
      • 미들웨어는 요청/응답을 가로채서 처리하는 중간 계층
      • 인증, 로깅, CORS, 압축 등


  • TrustedHostMiddleware

    from fastapi.middleware.trustedhost import TrustedHostMiddleware
    
    def add_middleware(app: FastAPI):
    
        app.add_middleware(
            TrustedHostMiddleware,
            allowed_hosts=["yourdomain.com", "*.yourdomain.com"]
        )
    
    • 악의적인 Host 헤더로 인한 캐시 중독, 리다이렉트 공격을 방지
    • Spring Security에서 Host Header 공격 방지와 비슷한 개념

      @Bean
      public SecurityFilterChain filterChain(HttpSecurity http) {
          http.headers().frameOptions().sameOrigin();
      }
      


  • logger

    import logging
    
    logger = logging.getLogger(name)
    
    • **logging.getLogger(__**name**__)**
      • 로거 인스턴스 생성
      • Spring의 LoggerFactory 같은 것
      • __name__
        • 현재 모듈명 ("app.core.middleware")
        • 모듈별로 구분된 로거를 만들어 로그 출처 식별 가능


  • @app.middleware("http")

    from fastapi import FastAPI, Request
    import time
    import logging
    
    def add_middleware(app: FastAPI):
    @app.middleware("http")
    async def log_requests(request: Request, call_next):
        start_time = time.time()
        response = await call_next(request)
        process_time = time.time() - start_time
    
        logger.info(
            f"{request.method} {request.url.path} - "
            f"Status: {response.status_code} - "
            f"Time: {process_time:.4f}s"
        )
        return response
    
    • @app.middleware("http")

      • 커스텀 HTTP 미들웨어를 등록하는 데코레이터
      • Spring Boot의 HandlerInterceptor를 함수로 구현한 것

    • Request

      • Http Request 객체
      • Spring Boot의 HttpServletRequest와 동일한 역할

    • call_next

      async def middleware(request: Request, call_next):
          # 전처리
          response = await call_next(request)  # 다음 미들웨어/라우터로 전달
          # 후처리
          return response
      
      • call_next다음 미들웨어나 실제 라우터 함수를 호출하는 함수

      • 실행 순서

        1. 미들웨어 A 전처리
        2. call_next() → 미들웨어 B 전처리
        3. call_next() → 실제 라우터 함수 실행
        4. 미들웨어 B 후처리
        5. 미들웨어 A 후처리
      • Spring Boot의 FilterChain.doFilter()와 동일한 개념

        public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) {
            // 전처리
            chain.doFilter(request, response);  // 다음 필터/컨트롤러로 전달
            // 후처리
        }
        




app/main.py

from fastapi import FastAPI

from app.core.config import settings
from app.core.middleware import add_middleware

app = FastAPI(
    title="Holiday Keeper API",
    description="공휴일 정보 관리 API",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

add_middleware(app)

@app.get("/")
async def root():
    return {"message": "Holiday Keeper API is running!"}

@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "service": "holiday-keeper-api",
        "version": "1.0.0"
    }

@app.get("/test")
async def test():
    return {
        "message": "테스트 성공",
        "data": ["공휴일1", "공휴일2", "공휴일3"]
    }

@app.get("/config")
async def get_config():
    if not settings.is_dev:
        return {"error": "개발 환경에서만 사용 가능"}

    return {
        "project_name": settings.project_name,
        "version": settings.version,
        "environment": settings.environment,
        "server_host": settings.server_host,
        "server_port": settings.server_port,
        "cors_origins": settings.cors_origins_list,
        "log_level": settings.log_level
    }

  • middleware.pyadd_middleware(app)를 활용하여 관심사 분리


  • middleware.py 도입 전

    from fastapi import FastAPI
    from fastapi.middleware.cors import CORSMiddleware
    
    app = FastAPI(
        title="Holiday Keeper API",
        description="공휴일 정보 관리 API",
        version="1.0.0",
        docs_url="/docs",
        redoc_url="/redoc"
    )
    
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.cors_origins_list,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"]
    )
    
2.3. DB 연결 설정 및 간단한 API 구현하기

app/core/database.py

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from app.core.config import settings
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker

sync_engine = create_engine(
    settings.sync_database_url,
    pool_pre_ping=True,
    echo=settings.is_dev
)

async_engine = create_async_engine(
    settings.database_url,
    pool_pre_ping=True,
    echo=settings.is_dev
)

SessionLocal = sessionmaker(
    bind=sync_engine,
    autocommit=False,
    autoflush=False,
)

AsyncSessionLocal = async_sessionmaker(
    bind=async_engine,
    autocommit=False,
    autoflush=False,
    expire_on_commit=False
)

async def get_db():
    async with AsyncSessionLocal() as db:
        yield db

def get_sync_db():
    sync_db = SessionLocal()
    try:
        yield sync_db
    finally:
        sync_db.close()



DB Engine

from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import create_async_engine
from app.core.config import settings

sync_engine = create_engine(
    settings.sync_database_url,
    pool_pre_ping=True,
    echo=settings.is_dev
)

async_engine = create_async_engine(
    settings.database_url,
    pool_pre_ping=True,
    echo=settings.is_dev
)
  • Engine

    • 데이터베이스 Connection Pool을 관리하는 객체
    • 데이터베이스 연결 생성 및 연결 재사용
    • 엔진은 전역에서 하나만 생성해서 공유하는 것이 권장된다
      • 엔진마다 별도의 연결 풀을 생성하면 낭비이다.
      • 또 엔진 생성에는 비용이 들기 때문에 좋지 않다.

  • create_engine 함수 (동기 엔진)

    • 동기 방식으로 엔진 생성
    • 블로킹 방식
      • 각 쿼리가 완료될 때까지 대기


  • create_async_engine 함수 (비동기 엔진)

    • 비동기 방식으로 엔진 생성
    • 논블로킹 방식
      • 쿼리 실행 중에 다른 작업 수행 가능


  • engine 파라미터

    • url: str | URL
      • 필수 인자
      • 데이터베에스 연결 URL
    • pool_pre_ping=True
      • connection pool에서 connection을 가져올 때 먼저 ping을 보내 connection 상태 확인
    • echo=settings.is_dev
      • 개발 모드일 때 실행되는 SQL 쿼리를 콘솔에 출력
      • settings.is_dev는 임의로 만든 환경 확인 함수
    • pool_size
      • 연결 풀의 기본 크기
      • default : 5
    • max_overflow
      • 풀이 가득 찰 때 추가로 생성할 수 있는 연결 수
      • default : 10
    • pool_timeout
      • 연결을 얻기 위해 대기할 최대 시간(초)
    • pool_recycle
      • 연결의 최대 생존 시간(초)
      • 여기서 정해진 시간이 지남에 따라 연결을 폐기 후 새로 생성한다.



DB 세션 생성

from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import async_sessionmaker

SessionLocal = sessionmaker(
    bind=sync_engine,        # 사용할 엔진
    autocommit=False,        # 자동 커밋 비활성화
    autoflush=False,         # 자동 플러시 비활성화
)

AsyncSessionLocal = async_sessionmaker(
    bind=async_engine,       # 비동기 엔진 바인딩
    autocommit=False,        # 자동 커밋 비활성화
    autoflush=False,         # 자동 플러시 비활성화
    expire_on_commit=False   # 커밋 후 객체 만료 방지
)
  • SessionLocal

    • 세션 팩토리 객체
    • SessionLocal()을 호출할 때마다 새로운 세션 인스턴스 생성
    • 각 요청/트랜잭션마다 독립적인 세션 사용

  • sessionmaker

    • 세션을 생성하는 팩토리 함수
    • sessionmaker - (동기)
    • async_sessionmaker - (비동기)
    • 호출할 때마다 새로운 세션 인스턴스 생성


  • session 파라미터

    • bind

      • 필수 인자
      • 연결할 엔진 지정
    • autocommit=False

      • ORM 세션 레벨에서의 autocommit 처리
      • 명시적으로 commit()을 호출해야 변경사항이 저장된다
      • 반드시 false로 해야 트랜잭션 관리를 할 수 있다
      • true인 경우, 즉시 DB에 저장되어 롤백 불가능!
    • autoflush=False

      • true : 기본값
      • False로 설정해야 ****쿼리 전에 자동으로 flush하지 않는다
      • 즉, 수동으로 제어하는 것으로 성능상 유리하다.
    • expire_on_commit

      • true : 기본값
      • False로 설정해야 ****커밋 후에도 객체의 속성에 접근 가능
      • expire_on_commit=False비동기에서 중요
        • 세션이 닫힌 후에도 해당 객체를 활용하기 위해서
    • class_

      • 사용할 세션 클래스 지정
      • 생략 시 기본값으로 지정된다
        • sessionmakerSession
        • async_sessionmakerAsyncSession
      • 커스템 세션 클래스의 경우 명시적으로 지정
    • info

      SessionLocal = sessionmaker(
          bind=engine,
          info={
              'app_name': 'my_app',
              'version': '1.0.0',
              'debug_mode': True
          }
      )
      
      # 저장 정보 사용 방법
      session.info['key']
      
      • 추가 세션 메타데이터 딕셔너리



DB 세션 생성

async def get_db():
    async with AsyncSessionLocal() as db:
        yield db

def get_sync_db():
    sync_db = SessionLocal()
    try:
        yield sync_db
    finally:
        sync_db.close()
  • async with 문법

    async with AsyncSessionLocal() as session:
        # 세션 사용
        pass  # 자동으로 session.close() 호출됨
    
    • async with비동기 컨텍스트 매니저
      • 즉, 세션의 생명주기 관리세션 생성/정리를 자동으로 처리
      • 진입 시 : AsyncSessionLocal()로 세션 생성
      • 종료 시 : 자동으로 await session.close() 호출
      • 예외 발생 시 : 자동으로 정리 작업 수행
    • 일반 with와 비슷하지만 __aenter____aexit__ 메서드가 비동기


  • try, yield, finally 문법

    async def get_db():
        async with AsyncSessionLocal() as session:
            try:
                yield session          # 세션을 제공
            finally:
                await session.close()  # 정리 작업
    
    • yield
      • 이 함수를 제너레이터로 만든다
      • FastAPI가 의존성 주입할 때 필요한 문법
    • try
      • 세션 사용 중 예외가 발생할 수 있는 구간
    • finally
      • 예외 발생 여부와 상관없이 반드시 실행되는 정리 코드

    • FastAPI에서는 의존성 주입 시 이 패턴을 사용한다
      1. 요청 시작 시 세션 생성
      2. 요청 처리 중 세션 사용
      3. 요청 완료 후 세션 정리

app/common/models/base.py (Base 모델)

from sqlalchemy import Column, Integer, DateTime, func
from sqlalchemy.orm import DeclarativeBase

class Base(DeclarativeBase):
    pass

class TimestampMixin:
    created_at = Column(DateTime, server_default=func.now(), nullable=False)
    updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), nullable=False)

class BaseModel(Base, TimestampMixin):
    __abstract__ = True

    id = Column(Integer, primary_key=True)
  • DeclarativeBase

    • SQLAlchemy 2.0에서 도입된 새로운 베이스 클래스
      • 메타클래스 자동 설정
        • DeclarativeMeta 메타클래스를 자동으로 적용
        • 복잡한 설정을 자동으로 적용
      • 타입 힌팅 개선
        • Python의 타입 시스템과 더 잘 통합 (IDE에서 자동완성 개선)
        • 타입 검사 지원
      • 레지스트리 관리
        • SQLAlchemy가 모든 테이블 정의를 중앙에서 관리
        • registry : 매핑된 클래스들의 레지스트리
      • 스키마 생성
        • 실제 데이터베이스에 테이블을 만들거나 삭제할 때 사용
          • Base.metadata.create_all(engine)
        • metadata 객체를 통해 테이블 생성/삭제 관리

    • 이전 방식

      # SQLAlchemy 1.x
      from sqlalchemy.ext.declarative import declarative_base
      
      Base = declarative_base()
      
      # SQLAlchemy 2.0+
      from sqlalchemy.orm import DeclarativeBase
      
      class Base(DeclarativeBase):
          pass
      


  • 네이밍 관례

    • Base

      • SQLAlchemy의 루트 베이스 클래스
      • 모든 모델의 최상위 부모
      • 보통 Base 또는 DeclarativeBase로 명명

    • BaseModel

      • 공통 필드와 로직을 담는 추상 베이스 모델
      • Base상속받아 확장한다

    • TimestampMixin

      • 믹스인 패턴 : 특정 기능을 여러 클래스에 주입
      • 보통 created_at, updated_at 같은 공통 시간 필드 제공
      • 이름에 Mixin 접미사로 목적 명시


  • __abstract__ = True

    • __abstract__의 의미
      • 이 클래스는 실제 테이블을 생성하지 않는다고 설정

      • 다른 모델들이 상속받기 위한 공통 속성 정의 역할
      • Spring Boot의 BaseEntity와 유사
      • 예시

        class TimestampMixin:
            created_at = Column(DateTime, server_default=func.now(), nullable=False)
            updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), nullable=False)
        
        class BaseModel(Base, TimestampMixin):
            __abstract__ = True  # 테이블 생성 안 함
            id = Column(Integer, primary_key=True)
        
        class Country(BaseModel):  # BaseModel 상속
            __tablename__ = 'countries'  # 실제 테이블 생성
            name = Column(String(100))
        
        • BaseModel ⇒ 테이블 생성 안 됨
        • Countrycountries 테이블 생성
          • id, created_at, updated_at, name 포함


  • func

    • SQLAlchemy의 SQL 함수 네임스페이스
    • 날짜/시간

      func.now()                        # 현재 시간
      func.current_date()               # 현재 날짜
      func.current_timestamp()          # 현재 타임스탬프
      func.date_part('year', date_col)  # 날짜 부분 추출
      

    • 집계 함수

      func.count(User.id)          # COUNT
      func.sum(Order.amount)       # SUM
      func.avg(Product.price)      # AVERAGE
      func.min(User.age)           # MIN
      func.max(User.age)           # MAX
      

    • 문자열 함수

      func.upper(User.name)        # 대문자 변환
      func.lower(User.email)       # 소문자 변환
      func.length(User.name)       # 문자열 길이
      func.concat(User.first_name, ' ', User.last_name)    # 문자열 연결
      

    • 수학 함수

      func.abs(column)             # 절댓값
      func.round(column, 2)        # 반올림
      func.floor(column)           # 내림
      func.ceiling(column)         # 올림
      




app/domains/country/models.py (Country 모델)

from sqlalchemy import Column, Integer, String, DateTime
from sqlalchemy.orm import relationship
from app.common.models.base import BaseModel

class Country(BaseModel):
    __tablename__ = 'countries'

    country_code = Column(String(10), unique=True, nullable=False)
    name = Column(String(100), nullable=False)

    holidays = relationship("Holiday", back_populates="country")
  • Column( ) 함수

    Column(type_, *args, **kwargs)
    
    • 데이터 타입

      Column(Integer)           # 정수형
      Column(String(100))       # 최대 100자 문자열
      Column(Text)              # 긴 텍스트
      Column(DateTime)          # 날짜/시간
      Column(Boolean)           # 불린
      Column(Float)             # 실수
      Column(DECIMAL(10, 2))    # 정밀한 소수점
      

    • 제약 조건

      Column(Integer, primary_key=True)            # 기본키
      Column(String(50), nullable=False)           # NOT NULL
      Column(String(100), unique=True)             # UNIQUE 제약
      Column(Integer, default=0)                   # 기본값
      Column(DateTime, server_default=func.now())  # 서버 기본값
      
      • 서버 기본값이란
        • DB 레벨의 기본값을 말한다.
        • 일반 기본값의 경우, 쿼리를 넘길 때 기본값으로 설정하여 내보내는 것이다.
        • 서버 기본값은 쿼리에서 생략해도 DB 서버가 알아서 NOW()로 입력한다.

    • 인덱스와 성능

      Column(String(50), index=True)        # 인덱스 생성
      Column(Integer, primary_key=True)     # PK는 자동으로 인덱스
      
      • PK와 UK는 자동으로 인덱스 설정된다
      • FK는 DB마다 다르다

    • 업데이트 관련

      Column(DateTime, onupdate=func.now())             # 업데이트 시 자동 갱신
      Column(DateTime, server_onupdate=func.now())      # 서버 레벨 업데이트
      

    • 외래키

      Column(Integer, ForeignKey('users.id'))  # 외래키 설정
      
      • ForeignKey 설정에는 실제 테이블명 사용

    • 기타 옵션

      Column(String(100), 
          comment='사용자 이름',            # 주석
          info={'label': 'Username'},       # 추가 메타데이터
          quote=True,                       # 컬럼명 따옴표 처리
          autoincrement=True)               # 자동 증가 (PK의 기본값)
      
      • autoincrement=True
        • DB별 자동 증가 전략 선택
        • 다만, PK + Integer의 경우 자동 증가가 자동 적용된다.
      • quote=True
        • 예약어나 특수문자가 포함된 컬럼명인 경우 설정

          따옴표 처리로 가능하게 한다



  • relationship( )

    • 기본 관계 설정

      # 현재 코드
      holidays = relationship("Holiday", back_populates="country")
      
      relationship("User")                           # 기본 관계
      relationship("User", backref="orders")         # 역참조 자동 생성
      relationship("User", back_populates="orders")  # 양방향 명시적 설정
      
      • relationship("User", XXX)
        • 대상 클래스는 모델 클래스명을 사용한다.

    • 외래키 관련

      relationship("User", foreign_keys=[user_id])     # 외래키 명시
      relationship("User", primaryjoin="Order.user_id==User.id")  # 조인 조건 명시
      

    • 로딩 전략

      relationship("User", lazy="select")      # 기본: 필요시 로딩
      relationship("User", lazy="joined")      # 조인으로 즉시 로딩  
      relationship("User", lazy="subquery")    # 서브쿼리로 로딩
      relationship("User", lazy="dynamic")     # 쿼리 객체 반환
      relationship("User", lazy="selectin")    # IN 절 사용해서 로딩
      
      • Default : lazy="select"
        • 현업에서 가장 많이 사용
        • 기본값 사용 시 생략 가능
        • 테이블 정의는 기본값으로 두고, 쿼리할 때 상황에 맞게 제어
      • 성능 최적화가 필요할 때 ⇒ N+1 문제 해결
        • lazy="joined"
        • lazy="selectin"

    • 컬렉션 타입

      relationship("Tag", collection_class=set)        # Set으로 관리
      relationship("Tag", collection_class=list)       # List로 관리 (기본값)
      
      • Default : collection_class=list

    • 삭제 동작

      relationship("Order", 
                  cascade="all, delete-orphan",    # 연쇄 삭제
                  passive_deletes=True)            # DB 레벨 삭제
      
      • cascade="all”은 여러 cascade 옵션의 조합이다.
        • save-update : 부모 저장 시 자식도 저장
        • merge : 부모 병합 시 자식도 병합
        • refresh : 부모 새로고침 시 자식도 새로고침
        • expunge : 부모 세션에서 제거 시 자식도 제거
        • delete : 부모 삭제 시 자식도 삭제
      • delete-orphan
        • 부모와 연결이 끊어진 자식 자동 삭제
          user = session.query(User).first()
          order = user.orders[0]  
          user.orders.remove(order)  # 관계에서 제거
          session.commit()
          
          • Order 객체가 자동으로 DB에서 삭제됨
      • passive_deletes=True
        • Default : passive_deletes=False
        • False는 삭제 시 SQLAlchemy가 SELECT로 모든 자식들을 찾아 하나씩 DELETE
          • 느리고 메모리 많이 사용
        • True는 삭제 시 DB의 CASCADE가 알아서 관련 자식 삭제
          • 빠르다

    • 정렬

      relationship("Comment", 
                  order_by="Comment.created_at.desc()")  # 정렬 조건
      
      • JPA의 @OrderBy와 동일
      • 데이터를 가져올 때 정렬 조건

    • 조건부 관계

      class Order(BaseModel):
          # 모든 상품
          product = relationship("Product")
          
          # 활성화된 상품만
          active_product = relationship("Product",
                                      primaryjoin="and_(Order.product_id==Product.id, "
                                              "Product.is_active==True)")
      
      # 사용
      order = session.query(Order).first()
      print(order.product)        # 모든 상품
      print(order.active_product) # 활성화된 상품만 (조건에 맞는 것만)
      

    • 양방향 관계

      # Country 모델
      class Country(BaseModel):
          holidays = relationship("Holiday", back_populates="country")
      
      # Holiday 모델  
      class Holiday(BaseModel):
          country_id = Column(Integer, ForeignKey('countries.id'))
          country = relationship("Country", back_populates="holidays")
      
      • 직접 FK 설정을 해줘야 실제 연결이 된다
        • 즉, 단방향이든 양방향이든 FK 설정이 완료되어야지 실제 연결된다

app/domains/holiday/models.py (Holiday 모델)

from sqlalchemy import Column, Integer, String, Date, Boolean, UniqueConstraint, Index, ForeignKey
from sqlalchemy.orm import relationship
from app.common.models.base import BaseModel

class Holiday(BaseModel):
    __tablename__ = 'holidays'
    __table_args__ = (
        UniqueConstraint('country_id', 'date', 'name', name='uk_1'),

        Index('idx_year', 'holiday_year'),
        Index('idx_country_year', 'country_id', 'holiday_year'),
    )

    date = Column(Date, nullable=False)
    name = Column(String(200))
    local_name = Column(String(200))
    holiday_year = Column(Integer, nullable=False)
    launch_year = Column(Integer)
    is_global = Column(Boolean, default=False)
    types = Column(String(100))
    counties = Column(String(500))

    country_id = Column(Integer, ForeignKey('countries.id'), nullable=False, index=True)

    country = relationship("Country", back_populates="holidays")
  • 1대다 관계 설정

    • Spring Boot의 경우

      • 어노테이션 기반으로 관계 정의
      • mappedBy로 연관관계 주인이 아님을 명시
        • 자동으로 FK는 상대방이 가진다

    • FastAPI의 경우

      • FK 필드를 통해 관계를 결정
      • FK를 가진 테이블이 연관관계의 주인이 된다


  • back_populates vs backref

    • 둘 모두 양방향 설정으로 양쪽에서 모두 연관 대상에 접근할 수 있다

    • backref

      • 한쪽 모델에서만 관계를 정의한다
        • 다른 쪽 모델은 relationship 정의 안 하지만, 자동 생성된다.
      • 간단한 프로젝트에서 주로 사용

    • back_populates

      • 양쪽 모델 모두에서 관계를 정의한다
      • 명시적으로 설정하여 각 관계가 어떻게 설정되어 있는지 명확하다.
      • 또한 각 관계마다 다른 옵션 설정 가능
      • IDE 지원 : 자동완성과 타입 검사 더 잘됨


  • 다대다 관계 설정

    • 단순 다대다 관계 (추가 필드 없는 경우)

      • Spring Boot의 경우

        @Entity
        public class User {
            @ManyToMany
            @JoinTable(
                name = "user_roles",
                joinColumns = @JoinColumn(name = "user_id"),
                inverseJoinColumns = @JoinColumn(name = "role_id")
            )
            private List<Role> roles;
        }
        
        @Entity
        public class Role {
            @ManyToMany(mappedBy = "roles")
            private List<User> users;
        }
        

      • FastAPI의 경우

        from sqlalchemy import Table
        
        user_role_table = Table(
            'user_roles',
            Base.metadata,
            Column('user_id', Integer, ForeignKey('users.id')),
            Column('role_id', Integer, ForeignKey('roles.id'))
        )
        
        class User(BaseModel):
            __tablename__ = 'users'
            name = Column(String(100))
            
            roles = relationship("Role", secondary=user_role_table, back_populates="users")
        
        class Role(BaseModel):
            __tablename__ = 'roles'
            name = Column(String(50))
            
            users = relationship("User", secondary=user_role_table, back_populates="roles")
        
        • 중간 테이블 정의


    • 복잡한 다대다 관계 (중간 테이블에 추가 필드)

      • Spring Boot의 경우

        @Entity
        public class UserRole extends BaseEntity {
            @ManyToOne(fetch = FetchType.LAZY)
            @JoinColumn(name = "user_id")
            private User user;
        
            @ManyToOne(fetch = FetchType.LAZY)
            @JoinColumn(name = "role_id")
            private Role role;
            
            private LocalDateTime assignedDate;
        }
        
        @Entity
        public class User extends BaseEntity {
            @OneToMany(mappedBy = "user", cascade = CascadeType.ALL, orphanRemoval = true)
            private List<UserRole> userRoles = new ArrayList<>();
        }
        
        @Entity
        public class Role extends BaseEntity {
            @OneToMany(mappedBy = "role", cascade = CascadeType.ALL, orphanRemoval = true)
            private List<UserRole> roleUsers = new ArrayList<>();
        }
        

      • FastAPI의 경우

        from sqlalchemy import DateTime, func
        
        class UserRole(BaseModel):
            __tablename__ = 'user_roles'
            
            user_id = Column(Integer, ForeignKey('users.id'), primary_key=True)
            role_id = Column(Integer, ForeignKey('roles.id'), primary_key=True)
            assigned_date = Column(DateTime, server_default=func.now())
            
            user = relationship("User", back_populates="user_roles")
            role = relationship("Role", back_populates="role_users")
        
        class User(BaseModel):
            __tablename__ = 'users'
            name = Column(String(100))
            
            user_roles = relationship("UserRole", back_populates="user")
        
        class Role(BaseModel):
            __tablename__ = 'roles' 
            name = Column(String(50))
            
            role_users = relationship("UserRole", back_populates="role")
        

app/common/crud/base.py

from typing import Type, TypeVar, Any, Optional, List
from sqlalchemy import select, func, text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase

ModelType = TypeVar("ModelType", bound=DeclarativeBase)

class CRUDBase:
    def __init__(self, model: Type[ModelType]):
        self.model = model

    async def get(self, db: AsyncSession, obj_id: Any) -> Optional[ModelType]:
        return await db.get(self.model, obj_id)

    async def get_by_field(self, db: AsyncSession, field_name: str, value: Any) -> Optional[ModelType]:
        stmt = select(self.model).where(getattr(self.model, field_name) == value)
        result = await db.execute(stmt)
        return result.scalars().one_or_none()

    async def get_by_fields(self, db: AsyncSession, **kwargs: Any) -> Optional[ModelType]:
        conditions = [getattr(self.model, key) == value for key, value in kwargs.items()]
        stmt = select(self.model).where(*conditions)
        result = await db.execute(stmt)
        return result.scalars().one_or_none()
    
    async def get_multi(self, db: AsyncSession, *, skip: int = 0, limit: int = 100) -> List[ModelType]:
        stmt = select(self.model).offset(skip).limit(limit)
        result = await db.execute(stmt)
        return result.scalars().all()

    async def get_multi_by_fields(self, db: AsyncSession, *, skip: int = 0, limit: int = 100, **kwargs: Any) -> List[ModelType]:
        conditions = [getattr(self.model, key) == value for key, value in kwargs.items()]
        stmt = select(self.model).where(*conditions).offset(skip).limit(limit)
        result = await db.execute(stmt)
        return result.scalars().all()

    async def get_all(self, db: AsyncSession) -> List[ModelType]:
        stmt = select(self.model)
        results = await db.execute(stmt)
        return results.scalars().all()

    async def count(self, db: AsyncSession) -> int:
        stmt = select(func.count()).select_from(self.model)
        result = await db.execute(stmt)
        return result.scalar_one()

    async def create(self, db: AsyncSession, *, obj_in: Any) -> ModelType:
        db_obj = self.model(**obj_in.model_dump())
        db.add(db_obj)
        await db.commit()
        await db.refresh(db_obj)
        return db_obj

    async def update(self, db: AsyncSession, db_obj: ModelType, obj_in: Any) -> ModelType:
        update_data = obj_in.model_dump(exclude_unset=True)
        for key, value in update_data.items():
            setattr(db_obj, key, value)
        db.add(db_obj)
        await db.commit()
        await db.refresh(db_obj)
        return  db_obj

    async def remove(self, db: AsyncSession, *, db_obj: ModelType) -> ModelType:
        await db.delete(db_obj)
        await db.commit()
        return db_obj




초기 설정

  • TypeVar

    from typing import Type, TypeVar, Any, Optional, List
    from sqlalchemy.orm import DeclarativeBase
    
    ModelType = TypeVar("ModelType", bound=DeclarativeBase)
    
    class CRUDBase:
        def __init__(self, model: Type[ModelType]):
            self.model = model
    
    • TypeVar란

      • Python의 타입 힌팅 시스템에서 제네릭 타입 변수를 정의하는 문법

    • "ModelType"

      • TypeVar의 이름을 문자열로 지정
      • 관례적으로 변수명과 동일하게 설정

    • bound=DeclarativeBase

      • TypeVar가 가질 수 있는 타입의 upper bound을 제한
      • 즉, 이 경우 DeclarativeBase를 상속받은 클래스만 올 수 있다
        • SQLAlchemy ORM 모델들은 모두 DeclarativeBase를 상속받는다


    • 그 외 파라미터

      # 제약 없는 TypeVar
      T = TypeVar('T')
      
      # 특정 타입들로만 제한
      NumberType = TypeVar('NumberType', int, float, complex)
      
      # 공변성/반공변성 설정
      T_co = TypeVar('T_co', covariant=True)              # 공변
      T_contra = TypeVar('T_contra', contravariant=True)  # 반공변
      
      # bound와 constraints 조합
      UserType = TypeVar('UserType', bound=BaseModel, int, str)
      


    • 현업 명명 사례

      # 모델/엔티티 관련
      ModelType = TypeVar('ModelType', bound=BaseModel)
      EntityT = TypeVar('EntityT', bound=Entity)
      DomainT = TypeVar('DomainT')
      
      # 서비스/리포지토리 패턴
      ServiceT = TypeVar('ServiceT', bound=BaseService)
      RepoT = TypeVar('RepoT', bound=BaseRepository)
      
      # HTTP 관련
      RequestT = TypeVar('RequestT', bound=BaseRequest)
      ResponseT = TypeVar('ResponseT', bound=BaseResponse)
      
      # 일반적인 명명
      T = TypeVar('T')  # 가장 일반적
      K = TypeVar('K')  # Key 타입
      V = TypeVar('V')  # Value 타입
      



  • __init__

    • __init__ 함수는 Python의 생성자 메서드
    • __init__ 함수

      • Python의 인스턴스 초기화 메서드 (생성자)
      • 객체가 생성될 때 자동으로 호출
      • 인스턴스의 초기 상태를 설정하는 역할


    • self

      • 인스턴스 자신을 가리키는 참조
      • 인스턴스 메서드의 첫 번째 파라미터로 자동 전달
      • 정적 메서드가 아닌 인스턴스 메서드임을 의미


    • model: Type[ModelType]

      • Type[ModelType]
        • 제네릭 타입 변수 ModelType의 타입을 받는다는 타입 힌트
      • 인스턴스가 아닌 클래스 자체를 받는다




단건 / 다건 조회

  • SQLAlchemy ORM

    • 클래스는 테이블의 메타데이터를 담고 있는 설계도 역할을 한다.
    • 객체를 통해 메서드를 호출하지만 내부에서는 클래스 자체만을 사용



  • 단건 조회

    async def get(self, db: AsyncSession, obj_id: Any) -> Optional[ModelType]:
        return await db.get(self.model, obj_id)
    
    async def get_by_field(self, db: AsyncSession, field_name: str, value: Any) -> Optional[ModelType]:
        stmt = select(self.model).where(getattr(self.model, field_name) == value)
        result = await db.execute(stmt)
        return result.scalars().one_or_none()
    
    async def get_by_fields(self, db: AsyncSession, **kwargs: Any) -> Optional[ModelType]:
        conditions = [getattr(self.model, key) == value for key, value in kwargs.items()]
        stmt = select(self.model).where(*conditions)
        result = await db.execute(stmt)
        return result.scalars().one_or_none()
    
    • SQLAlchemy의 get( ) 함수

      • Primary Key로 단일 객체 조회하는 SQLAlchemy 메서드
      • db.get(self.model, obj_id)
      • 첫 번째 인자 : 조회할 모델 클래스
      • 두 번째 인자 : Primary Key 값
      • 반환 : 해당 객체 또는 None


    • SQLAlchemy의 select( ) 함수

      stmt = select(User)  # SELECT * FROM user
      stmt = select(User.name, User.email)  # SELECT name, email FROM user
      
      • 쿼리 빌더 함수로, SQL SELECT 문을 생성하는 SQLAlchemy 메서드


    • SQLAlchemy의 where( ) 함수

      # 단일 조건
      stmt = select(User).where(User.id == 1)
      
      # 여러 조건 (AND)
      stmt = select(User).where(User.active == True, User.age > 18)
      
      # 또는 (OR)
      stmt = select(User).where(User.active == True).where(User.age > 18)
      
      • SELECT 문에 WHERE 조건을 추가하는 SQLAlchemy 메서드


    • getattr( ) 함수

      class User:
          name = "John"
          age = 25
      
      # 일반적인 속성 접근
      user.name  # "John"
      
      # getattr을 사용한 동적 속성 접근
      getattr(user, "name")  # "John"
      getattr(user, "salary", 0)  # 기본값 0 (속성이 없을 때)
      
      • 객체의 속성을 동적으로 가져오는 Python 내장 함수
      • 해당 속성이 없을 때 기본값을 지정해서 가져올 수 있다.


    • **kwargs 문법

      def get_by_fields(self, db: AsyncSession, **kwargs: Any):
          # kwargs = {"name": "John", "age": 25, "active": True}
          print(kwargs)  # {'name': 'John', 'age': 25, 'active': True}
          
          # 딕셔너리 순회
          for key, value in kwargs.items():
              print(f"{key}: {value}")
      
      # 사용 예시
      repo.get_by_fields(db, name="John", age=25, active=True)
      
      • 키워드 인자들을 딕셔너리로 받는 Python 문법
      • 딕셔너리 순회를 통해 많이 사용한다.
        • dict 클래스의 메서드인 items() 등 활용
      • 2개 이상의 필드로 조건부 조회하는 경우 정확한 숫자를 몰라도 설정 가능



  • SQLAlchemy의 execute( ) 함수

    # 비동기 실행
    result = await db.execute(stmt)
    
    # 동기 실행
    result = db.execute(stmt)
    
    • 데이터베이스에 쿼리를 실행하는 SQLAlchemy 메서드

    • execute() 결과의 기본 구조

      # 직접 fetchall() 사용
      rows = result.fetchall()
      print(type(rows[0]))  # <class 'sqlalchemy.engine.Row'>
      
      # Row 객체 접근
      for row in rows:
          # 방법 1: 인덱스로 접근
          print(row[0])  # 첫 번째 컬럼
          
          # 방법 2: 속성명으로 접근
          print(row.id)     # User.id
          print(row.name)   # User.name
          
          # 방법 3: 딕셔너리처럼 접근
          print(row['id'])
          print(row['name'])
          
          # Row를 딕셔너리로 변환
          user_dict = row._asdict()
          print(user_dict)  # {'id': 1, 'name': 'John', 'email': 'john@example.com'}
      
    • Row 객체는 ORM 모델 객체의 메서드를 사용할 수 없다.



  • SQLAlchemy의 scalars( ) 함수

    # ORM 객체
    result = await db.execute(select(User))
    users = result.scalars().all()
    user_obj = users[0]
    print(type(user_obj))
    user_obj.some_method()  # User 모델의 메서드 사용 가능
    
    • Result 객체에서 스칼라 값들만 추출하는 SQLAlchemy 메서드

    • scalars( ) 대신 사용할 수 있는 메서드들

      • mappings() - 딕셔너리 형태로 결과 반환

        result = await db.execute(select(User))
        user_dicts = result.mappings().all()
        # [{'id': 1, 'name': 'John', 'email': 'john@example.com'}, ...]
        
        # API 응답에 바로 사용 가능
        return JSONResponse(content=user_dicts)
        

      • 튜플 결과

        result = await db.execute(select(User.id, User.name))
        tuples = result.all()  # [(1, 'John'), (2, 'Jane'), ...]
        
        # 또는 scalars().all()로 첫 번째 컬럼만
        result = await db.execute(select(User.name))
        names = result.scalars().all()  # ['John', 'Jane', ...]
        


    • 뒤에 올 수 있는 메서드

      # 모든 결과 가져오기
      .all()          # 리스트로 모든 결과 반환
      .fetchall()     # all()과 동일
      
      # 단일 결과 가져오기
      .first()        # 첫 번째 결과 또는 None
      .one()          # 정확히 하나의 결과 (없거나 여러 개면 예외)
      .one_or_none()  # 하나의 결과 또는 None (여러 개면 예외)
      
      # 고유값들만
      .unique()       # 중복 제거
      
      # 파티셔닝
      .partitions(size)  # 지정된 크기로 결과를 나눔
      
      • all( ) 메서드의 특징
        • 데이터가 없는 경우 빈 리스트 [] 반환 (None이 아님)
        • 항상 리스트 타입 보장




  • SQLAlchemy의 scalar( ) 함수

    • 쿼리 결과의 첫 번째 행, 첫 번째 컬럼의 값을 추출
      result.scalar()              # 첫 번째 행의 첫 번째 컬럼 (None 가능)
      result.scalar_one()          # 정확히 1개 값 (0개나 2개 이상이면 예외)
      result.scalar_one_or_none()  # 1개 값 또는 None (2개 이상이면 예외)
      


    • 사용 대상

      • COUNT, SUM, MAX, MIN 등 집계 함수 결과
      • 단일 값을 반환하는 쿼리
      • 존재 여부 확인 (EXISTS)


    • scalar( )와 scalars( )의 차이점

      • scalar() - 단일 값만 반환
      • scalars().first() - ORM 객체 하나
      • scalars().all() - ORM 객체 리스트
      • 즉, 일반 상세 조회 등에서는 scalar( )를 쓰지 않는다. 객체가 필요하기 때문에


    • 사용 예시

      # 1. COUNT - 가장 많이 사용
      async def get_user_count(db: AsyncSession) -> int:
          result = await db.execute(select(func.count(User.id)))
          return result.scalar()  # 정수 반환 (예: 150)
      
      # 2. 조건부 COUNT
      async def get_active_user_count(db: AsyncSession) -> int:
          result = await db.execute(
              select(func.count(User.id)).where(User.active == True)
          )
          return result.scalar()
      
      # 3. MAX, MIN 값 조회
      async def get_latest_user_id(db: AsyncSession) -> Optional[int]:
          result = await db.execute(select(func.max(User.id)))
          return result.scalar()  # 최대 ID 값
      
      # 4. SUM 계산
      async def get_total_order_amount(db: AsyncSession, user_id: int) -> float:
          result = await db.execute(
              select(func.sum(Order.amount)).where(Order.user_id == user_id)
          )
          return result.scalar() or 0.0
      
      # 5. EXISTS 확인 (매우 자주 사용)
      async def email_exists(db: AsyncSession, email: str) -> bool:
          result = await db.execute(
              select(exists().where(User.email == email))
          )
          return result.scalar()  # True/False
      
      # 6. 단일 컬럼 값 하나만 가져오기
      async def get_user_name(db: AsyncSession, user_id: int) -> Optional[str]:
          result = await db.execute(
              select(User.name).where(User.id == user_id)
          )
          return result.scalar()  # 이름 문자열 또는 None
      



  • 다건 조회

    async def get_multi(self, db: AsyncSession, *, skip: int = 0, limit: int = 100) -> List[ModelType]:
        stmt = select(self.model).offset(skip).limit(limit)
        result = await db.execute(stmt)
        return result.scalars().all()
    
    async def get_multi_by_fields(self, db: AsyncSession, *, skip: int = 0, limit: int = 100, **kwargs: Any) -> List[ModelType]:
        conditions = [getattr(self.model, key) == value for key, value in kwargs.items()]
        stmt = select(self.model).where(*conditions).offset(skip).limit(limit)
        result = await db.execute(stmt)
        return result.scalars().all()
    
    async def get_all(self, db: AsyncSession) -> List[ModelType]:
            stmt = select(self.model)
            results = await db.execute(stmt)
            return results.scalars().all()
    
    • offset( )과 limit( ) 함수

      stmt = select(User).offset(10).limit(5)
      # 생성되는 SQL: SELECT * FROM user LIMIT 5 OFFSET 10
      # 10개를 건너뛰고 5개만 가져와라
      
      • SQLAlchemy의 페이징 처리를 위한 SQL 함수들


    • * 키워드 전용 인자 문법

      def func(a, b, *, c, d):
          pass
      
      # 올바른 호출
      func(1, 2, c=3, d=4)  # ✅
      
      # 잘못된 호출  
      func(1, 2, 3, 4)      # ❌ TypeError: 위치 인자로 c, d 전달 불가
      
      • * 이후의 매개변수들반드시 키워드로 전달해야 한다


      • 키워드 인자와 위치 인자

        • 위치 인자의 경우, 순서에 따라 매개변수에 값이 할당된다.

          greet("김철수", 25, "서울")     # ✅
          greet(25, "서울", "김철수")     # ❌ or 의도와 다른 결과
          
        • 키워드 인자의 경우, 매개변수 이름을 명시하여 할당한다.

          greet(name="김철수", age=25, city="서울")  # ✅
          greet(city="서울", name="김철수", age=25)  # ✅
          

      • 이렇게 키워드 인자로 사용하는 이유

        users = await crud.get_multi(db, skip=20, limit=10)
        
        • 무엇을 건너뛰고 몇 개 가져올지 명확하게 명시하기 위해




집계 함수

async def count(self, db: AsyncSession) -> int:
    stmt = select(func.count()).select_from(self.model)
    result = await db.execute(stmt)
    return result.scalar_one()

async def count_by_fields(self, db: AsyncSession, **kwargs: Any) -> int:
    conditions = [getattr(self.model, key) == value for key, value in kwargs.items()]
    stmt = select(func.count()).select_from(self.model).where(*conditions)
    result = await db.execute(stmt)
    return result.scalar_one()
  • 사용 예시

    total_users = await user_crud.count(db)  # 전체 사용자 수
    active_users = await user_crud.count_by_fields(db, active=True)  # 활성 사용자 수
    


  • scalar_one( )

    • scalar_one()

      • 오직 단 하나의 값만 나와야 한다
      • scalar_one()을 사용하는 이유는 COUNT 함수는 항상 숫자를 반환하기 때문
      • 조건에 맞는 데이터가 없는 경우 null이 아닌 0을 반환


    • 다른 집계 함수

      # MAX, MIN, SUM, AVG - NULL 가능
      SELECT MAX(age) FROM users;           -- 데이터 없으면 NULL
      SELECT SUM(salary) FROM users;        -- 데이터 없으면 NULL  
      SELECT AVG(age) FROM users;           -- 데이터 없으면 NULL
      
      • 그러므로, 다른 집계함수에서는 scalar()를 쓰거나 scalar_one_or_none() 사용
      • scalar()결과가 여러 개여도 첫번째 것만 가져온다
      • scalar_one_or_none()결과가 여러 개라면 예외를 발생시킨다



  • SQLAlchemy의 select_from( ) 함수

    select(User)  
    # SQL: SELECT * FROM user
    
    select(func.count()).select_from(User)
    # SQL: SELECT COUNT(*) FROM user
    
    • FROM 절을 명시적으로 지정하는 함수
      • 일반적인 select() 함수는 FROM이 자동 추론된다.

      • 다만, 집계 함수를 쓰는 경우 FROM 절이 불분명한 경우가 있다

      • 컬럼을 명시하는 경우select_from 불필요

        stmt = select(func.count(User.id))
        
      • 제네릭 타입이므로 컬럼을 명시하기 힘들다


    • 서브 쿼리로 FROM 절을 사용하는 경우select_from 필요

      stmt = select(func.count()).select_from(select(User).subquery())
      # SQL: SELECT COUNT(*) FROM (SELECT * FROM user) AS subquery
      




생성 / 수정 / 삭제

async def create(self, db: AsyncSession, *, obj_in: Any) -> ModelType:
    db_obj = self.model(**obj_in.model_dump())
    db.add(db_obj)
    await db.commit()
    await db.refresh(db_obj)
    return db_obj

async def update(self, db: AsyncSession, db_obj: ModelType, obj_in: Any) -> ModelType:
    update_data = obj_in.model_dump(exclude_unset=True)
    for key, value in update_data.items():
        setattr(db_obj, key, value)
    db.add(db_obj)
    await db.commit()
    await db.refresh(db_obj)
    return  db_obj

async def remove(self, db: AsyncSession, *, db_obj: ModelType) -> ModelType:
    await db.delete(db_obj)
    await db.commit()
    return db_obj
  • obj_in

    • Pydantic Schemas가 들어오기 때문Any로 설정
    • Spring Boot에서 DTO를 받아 전환 후 DB 저장하는 패턴과 유사



  • model_dump( ) 함수

    • Pydantic BaseModel 클래스의 메서드
    • Pydantic 객체를 딕셔너리로 변환한다.

      # Pydantic 객체 생성
      user_data = UserCreate(name="김철수", email="kim@email.com", age=30)
      
      # model_dump() 호출 - 딕셔너리로 변환
      print(user_data.model_dump())
      # 출력: {'name': '김철수', 'email': 'kim@email.com', 'age': 30, 'active': True}
      


    • model_dump( )의 매개변수들

      • 기본 사용
        user_data.model_dump()
        # {'name': '김철수', 'email': 'kim@email.com', 'age': 30, 'active': True}
        

      • exclude_unset=True
        user_data.model_dump(exclude_unset=True)
        # {'name': '김철수', 'email': 'kim@email.com', 'age': 30}  # active는 기본값이므로 제외
        
        • 설정되지 않은 필드 제외
        • 기본값 : False
        • 일부 필드만 업데이트하는 경우 중요하다

      • include
        user_data.model_dump(include={'name', 'email'})
        # {'name': '김철수', 'email': 'kim@email.com'}
        
        • 특정 필드만 포함

      • exclude
        user_data.model_dump(exclude={'active'})
        # {'name': '김철수', 'email': 'kim@email.com', 'age': 30}
        
        • 특정 필드 제외

      • by_alias=True
        user_data.model_dump(by_alias=True)
        
        • 필드 별칭 사용

      • exclude_none=True
        user_data.model_dump(exclude_none=True)
        
        • None 값 제외



  • refresh( ) 함수

    • 데이터베이스와 객체를 동기화한다
    • 데이터 삽입 과정

      async def create(self, db: AsyncSession, *, obj_in: Any) -> ModelType:
          # 1. Pydantic → SQLAlchemy 객체 생성
          db_obj = self.model(**obj_in.model_dump())
          
          print(db_obj.id)  # None (아직 DB에 저장 전)
          
          # 2. 세션에 추가
          db.add(db_obj)
          print(db_obj.id)  # 여전히 None
          
          # 3. DB에 커밋 (실제 저장)
          await db.commit()
          print(db_obj.id)  # 여전히 None!
          
          # 4. 객체 새로고침 (DB에서 최신 데이터 가져옴)
          await db.refresh(db_obj)
          print(db_obj.id)  # 1 (DB에서 자동 생성된 ID)
          
          return db_obj
      
      1. Pydantic의 model_dump() 함수로 딕셔너리 전환
      2. 딕셔너리를 self.model(해당 모델)의 init 함수로 넘겨 모델 객체 생성
      3. db.add()로 DB 세션에 저장
      4. db.commit()을 통해 실제 DB에 저장
      5. db.refresh()를 통해 해당 객체의 정보를 DB에서 갱신
      6. Database의 종류가 무엇이든, SQLAlchemy는 DB 저장 후 ID를 가져온다



  • setattr( ) 함수

    class User:
        def __init__(self):
            self.name = ""
            self.age = 0
    
    user = User()
    
    # 일반적인 속성 설정
    user.name = "김철수"
    user.age = 25
    
    # setattr()을 사용한 동적 속성 설정
    setattr(user, "name", "김철수")  # user.name = "김철수"와 같음
    setattr(user, "age", 25)         # user.age = 25와 같음
    
    • getattr() 함수는 동적으로 객체의 속성의 값을 가져오는 것이다
    • setattr()동적으로 객체의 속성의 값을 설정하는 것이다

app/common/crud/base.py (수정)

from typing import Type, TypeVar, Any, Optional, List
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase

ModelType = TypeVar("ModelType", bound=DeclarativeBase)

class CRUDBase:
    def __init__(self, model: Type[ModelType]):
        self.model = model

    async def get(self, db: AsyncSession, obj_id: Any) -> Optional[ModelType]:
        return await db.get(self.model, obj_id)

    async def get_by_field(self, db: AsyncSession, field_name: str, value: Any) -> Optional[ModelType]:
        stmt = select(self.model).where(getattr(self.model, field_name) == value)
        result = await db.execute(stmt)
        return result.scalars().one_or_none()

    async def get_by_fields(self, db: AsyncSession, **kwargs: Any) -> Optional[ModelType]:
        conditions = [getattr(self.model, key) == value for key, value in kwargs.items()]
        stmt = select(self.model).where(*conditions)
        result = await db.execute(stmt)
        return result.scalars().one_or_none()

    async def get_multi(self, db: AsyncSession, *, skip: int = 0, limit: int = 100) -> List[ModelType]:
        stmt = select(self.model).offset(skip).limit(limit)
        result = await db.execute(stmt)
        return result.scalars().all()

    async def get_multi_by_fields(self, db: AsyncSession, *, skip: int = 0, limit: int = 100, **kwargs: Any) -> List[ModelType]:
        conditions = [getattr(self.model, key) == value for key, value in kwargs.items()]
        stmt = select(self.model).where(*conditions).offset(skip).limit(limit)
        result = await db.execute(stmt)
        return result.scalars().all()

    async def get_all(self, db: AsyncSession) -> List[ModelType]:
        stmt = select(self.model)
        results = await db.execute(stmt)
        return results.scalars().all()

    async def count(self, db: AsyncSession) -> int:
        stmt = select(func.count()).select_from(self.model)
        result = await db.execute(stmt)
        return result.scalar_one()

    async def count_by_fields(self, db: AsyncSession, **kwargs: Any) -> int:
        conditions = [getattr(self.model, key) == value for key, value in kwargs.items()]
        stmt = select(func.count()).select_from(self.model).where(*conditions)
        result = await db.execute(stmt)
        return result.scalar_one()

    async def create(self, db: AsyncSession, *, obj_in: Any) -> ModelType:
        db_obj = self.model(**obj_in.model_dump())
        db.add(db_obj)
        return db_obj

    async def update(self, db: AsyncSession, db_obj: ModelType, obj_in: Any) -> ModelType:
        update_data = obj_in.model_dump(exclude_unset=True)
        for key, value in update_data.items():
            setattr(db_obj, key, value)
        db.add(db_obj)
        return  db_obj

    async def remove(self, db: AsyncSession, *, db_obj: ModelType) -> ModelType:
        await db.delete(db_obj)
        return db_obj
  • FastAPI는 트랜잭션을 DB 세션으로 관리한다
    • Spring Boot 처럼 관리되지 않고 직접 관리해야 한다.

  • 문제점

    • 현재의 CRUDBase 클래스는 트랜잭션이 각각 설정되어 있는 것이 문제이다.
    • 비즈니스 로직 하나에 여러 모델을 생성하거나 수정, 삭제할 수 있는데 전체의 비즈니스 로직이 원자성을 가져야 한다
    • 하지만, 이 경우 개별적으로 트랜잭션이 종료되므로 서비스 레벨에서 트랜잭션을 관리해야 한다

  • 기존의 create, update, remove 함수commit()refresh()를 제거

  • 단, create 함수의 결과 객체에는 ID가 존재하지 않는다는 것을 명심해야 한다.

    • ID가 필요한 경우 Service Layer에서 중간 db.flush()를 해야 한다.




app/common/service/base.py

from typing import TypeVar, Any, Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from app.common.crud.base import CRUDBase, ModelType

ServiceType = TypeVar("ServiceType", bound="BaseService")

class BaseService:
    def __init__(self, crud: CRUDBase):
        self.crud = crud

    async def get_by_id(self, db: AsyncSession, obj_id: Any) -> Optional[ModelType]:
        return await self.crud.get(db, obj_id)

    async def get_by_field(self, db: AsyncSession, field_name: str, value: Any) -> Optional[ModelType]:
        return await self.crud.get_by_field(db, field_name, value)

    async def get_by_fields(self, db: AsyncSession, **kwargs: Any) -> Optional[ModelType]:
        return await self.crud.get_by_fields(db, **kwargs)

    async def get_multi(self, db: AsyncSession, *, skip: int = 0, limit: int = 100) -> List[ModelType]:
        return await self.crud.get_multi(db, skip=skip, limit=limit)

    async def get_multi_by_fields(self, db: AsyncSession, *, skip: int = 0, limit: int = 100, **kwargs: Any) -> List[ModelType]:
        return await self.crud.get_multi_by_fields(db, skip=skip, limit=limit, **kwargs)

    async def get_all(self, db: AsyncSession) -> List[ModelType]:
        return await self.crud.get_all(db)

    async def count(self, db: AsyncSession) -> int:
        return await self.crud.count(db)

    async def count_by_fields(self, db: AsyncSession, **kwargs: Any) -> int:
        return await self.crud.count_by_fields(db, **kwargs)

    async def create(self, db: AsyncSession, *, obj_in: Any) -> ModelType:
        db_obj = await self.crud.create(db, obj_in=obj_in)
        await db.commit()
        await db.refresh(db_obj)
        return db_obj

    async def update(self, db: AsyncSession, db_obj: ModelType, obj_in: Any) -> ModelType:
        updated_obj = await self.crud.update(db, db_obj=db_obj, obj_in=obj_in)
        await db.commit()
        await db.refresh(updated_obj)
        return updated_obj

    async def remove(self, db: AsyncSession, *, db_obj: ModelType) -> ModelType:
        removed_obj = await self.crud.remove(db, db_obj=db_obj)
        await db.commit()
        return removed_obj
  • 간단한 비즈니스 로직을 처리할 수 있는 Base 서비스이다.

    async def create(self, db: AsyncSession, *, obj_in: Any) -> ModelType:
        db_obj = await self.crud.create(db, obj_in=obj_in)
        await db.commit()
        await db.refresh(db_obj)
        return db_obj
    
    async def update(self, db: AsyncSession, db_obj: ModelType, obj_in: Any) -> ModelType:
        updated_obj = await self.crud.update(db, db_obj=db_obj, obj_in=obj_in)
        await db.commit()
        await db.refresh(updated_obj)
        return updated_obj
    
    async def remove(self, db: AsyncSession, *, db_obj: ModelType) -> ModelType:
        removed_obj = await self.crud.remove(db, db_obj=db_obj)
        await db.commit()
        return removed_obj
    
    • BaseCRUD에서 처리하지 못한 create, update, remove 함수트랜잭션 처리를 Service에서 처리한다

  • 모델 특화 서비스는 이 BaseService를 상속 받아 구현한다

app/core/deps.py

from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from app.domains.country.services import CountryService
from app.domains.country.crud import country_crud

def get_country_service(db: AsyncSession = Depends(get_db)) -> CountryService:
    return CountryService(country_crud)
  • import가 아닌 의존성 주입의 장점
    • Depends의존성 주입 체인을 만든다
    • 테스트 용이성
      • 라우터를 테스트할 때 Mock 객체로 쉽게 전환이 가능하다.
    • 객체 생성의 책임 분리



app/domains/country/crud.py

from models import Country
from app.common.crud.base import CRUDBase

class CRUDCountry(CRUDBase):
    def __init__(self, model: type[Country]):
        super().__init__(model)

country_crud = CRUDCountry(Country)
  • CRUDBase 상속 받아서 구성
  • 생성자와 super()를 통해 상위 클래스에 필요 인자를 전달한다.



app/domains/country/services.py

from app.common.services.base import BaseService
from crud import CRUDCountry

class CountryService(BaseService):
    def __init__(self, crud: CRUDCountry):
        super().__init__(crud)
  • BaseService 상속 받아서 구성
  • BaseService에서 type으로 클래스로 받지 않고 객체로 받고 있기 때문에 객체로 지정

app/domains/country/schemas.py

from pydantic import BaseModel, ConfigDict
from typing import Optional

class CountryBase(BaseModel):
    country_code: str

    model_config = ConfigDict(from_attributes=True)

class CountryCreate(CountryBase):
    name: str

class CountryUpdate(CountryCreate):
    name: Optional[str] = None
    country_code: Optional[str] = None

class CountryListItem(CountryBase):
    id: int

class CountryResponse(CountryCreate):
    id: int
    created_at: datetime
    updated_at: datetime
  • CountryBase

    • country_code 필드만 정의한 이유

      • 가장 기본적인 속성이며, 여러 DTO에서 공통으로 사용되기 때문
      • name 필드의 경우 필요 없는 경우가 있다

    • 이렇게 상위 클래스를 만드는 이유

      • 공통 필드 관리
      • model_config 설정 상속


  • CountryUpdate

    • Pydantic은 상속받은 클래스에서 필드를 오버라이딩하는 것을 허용



app/common/schemas/base.py

from pydantic import BaseModel
from typing import Generic, TypeVar, Optional

T = TypeVar("T")

class BaseResponse(BaseModel, Generic[T]):
    code: str
    message: str
    data: Optional[T] = None
  • 공통 응답 설정
    • Router에서 response_model로 설정한 Pydantic 모델과 Router 메서드의 반환값이 다른 경우 서버 오류를 반환
    • 따라서, 상태코드, 메시지 등과 함께 응답하기 위해선 공통 응답을 설정해야 한다.



app/domains/country/router.py

from fastapi import APIRouter
from fastapi.params import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List
from app.core.database import get_db
from app.core.deps import get_country_service
from app.common.schemas.base import BaseResponse
from services import CountryService
from schemas import CountryListItem

router = APIRouter(prefix="/countries")

@router.get("/", response_model=BaseResponse[List[CountryListItem]])
async def get_countries(
    country_service: CountryService = Depends(get_country_service),
    db: AsyncSession = Depends(get_db)
):
    countries = await country_service.get_all(db=db)
    response = [CountryListItem.model_validate(country) for country in countries]
    return {
        "code": "200",
        "message": "OK",
        "data": response
    }
  • model_validate( ) 함수

    • Pydantic BaseModel 클래스의 핵심 정적 메서드
    • 다양한 형태의 입력 데이터를 Pydantic 모델로 변환하고 유효성을 검사
    • Pydantic 모델에 from_attributes=True필수적으로 설정되어 있어야 한다.
    • 작동 원리
      1. 입력 데이터 수용
        • 딕셔너리, JSON, SQLAlchemy 모델 객체 등 다양한 형태
      2. from_attributes=True 활성화
        • 모델의 필드들을 속성 이름으로 매핑하고 Pydantic 객체를 생성
      3. 속성 기반 변환 후 유효성 검사
        • Pydantic 모델에 설정된 타입 힌트 참고

  • 리스트 컴프리헨션(List Comprehension) 문법

    • [변환된_값 for 요소 in 리스트_또는_이터러블]

    • 파이썬에서 반복문과 조건문을 사용하여 리스트를 한 줄로 간결하게 생성하는 문법
    • 리스트 컴프리헨션 예시

      countries_data = []
      for country in countries:
          converted_country = CountryListItem.model_validate(country)
          countries_data.append(converted_country)
      
      countries_data = [CountryListItem.model_validate(country) for country in countries]
      


  • Depend( )

    • FastAPI에서는 Depend( )를 통해 의존성을 주입한다
    • 이렇게 주입하는 것이 결합도 낮추고 테스트 용이하다.



app/main.py (수정)

from fastapi import FastAPI

from app.core.config import settings
from app.core.middleware import add_middleware
from app.domains.country.router import router as country_router

from app.domains.country.models import Country
from app.domains.holiday.models import Holiday

app = FastAPI(
    title="Holiday Keeper API",
    description="공휴일 정보 관리 API",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

add_middleware(app)

app.include_router(country_router, prefix="/api/v1", tags=["countries"])

@app.get("/")
async def root():
    return {"message": "Holiday Keeper API is running!"}

@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "service": "holiday-keeper-api",
        "version": "1.0.0"
    }

@app.get("/test")
async def test():
    return {
        "message": "테스트 성공",
        "data": ["공휴일1", "공휴일2", "공휴일3"]
    }

@app.get("/config")
async def get_config():
    if not settings.is_dev:
        return {"error": "개발 환경 전용 API"}

    return {
        "project_name": settings.project_name,
        "version": settings.version,
        "environment": settings.environment,
        "server_host": settings.server_host,
        "server_port": settings.server_port,
        "cors_origins": settings.cors_origins_list,
        "log_level": settings.log_level
    }

  • app.include_router()를 통해 router를 등록한다.

  • Model Import 하기

    from app.domains.country.models import Country
    from app.domains.holiday.models import Holiday
    
    • Alembic이 autogenerate 기능을 위해 사용하는 SQLAlchemy의 Base.metadata 객체는 모든 모델을 알고 있어야 한다.
    • SQLAlchemy는 모델들을 자동으로 찾아서 등록하지 않는다
    • 따라서 모델이 정의된 Python 모듈을 직접 import해야 Base.metadata에 등록된다.

Alembic 초기화와 마이그레이션

  • 초기화 (Initialization)

    • Alembic을 최초로 설정하는 과정
    • 프로젝트당 딱 한 번만 실행

    • 실행 방법
      alembic init alembic
      
      • 루트 디렉토리에서 명령어 입력
      • 결과로 아래의 파일이 생성된다.
        • alembic/ 디렉터리
        • alembic.ini

  • 마이그레이션 (Migration)

    • 데이터베이스 스키마를 변경하는 과정
    • 모델에 필드를 추가하거나, 테이블을 삭제하는 등 실제 데이터베이스 변경 사항이 있을 때마다 하는 과정
      • alembic revision 명령어로 스크립트를 생성
      • alembic upgrade 명령어로 그 변경 사항을 적용



alembic.ini (수정)

# database URL.  This is consumed by the user-maintained env.py script only.
# other means of configuring database URLs may be customized within the env.py
# file.
sqlalchemy.url = driver://user:pass@localhost/dbname

[post_write_hooks]
  • SQLAlchemy 데이터베이스 URL 수정
    • alembic.ini 파일에서 sqlalchemy.url = 부분을 찾는다.

      # database URL.  This is consumed by the user-maintained env.py script only.
      # other means of configuring database URLs may be customized within the env.py
      # file.
      sqlalchemy.url = mysql+pymysql://root:1234@127.0.0.1:3306/hk_fast_db?charset=utf8mb4
      
      [post_write_hooks]
      



alembic/env.py (수정)

from logging.config import fileConfig

from sqlalchemy import engine_from_config
from sqlalchemy import pool

from alembic import context

from app.core.config import settings

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
from app.common.models.base import Base
target_metadata = Base.metadata

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.

def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well.  By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.

    """
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()

def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    In this scenario we need to create an Engine
    and associate a connection with the context.

    """
    connectable = engine_from_config(
        {'sqlalchemy.url': settings.sync_database_url},
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(
            connection=connection, target_metadata=target_metadata
        )

        with context.begin_transaction():
            context.run_migrations()

if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
  • Alembic이 SQLAlchemy 모델을 인식하도록 수정해야 한다.

  • target_metadata = None 설정

    • target_metadata = None을 SQLAlchemy의 Base 클래스를 임포트로 변경

      # add your model's MetaData object here
      # for 'autogenerate' support
      # from myapp import mymodel
      # target_metadata = mymodel.Base.metadata
      from app.common.models.base import Base
      from app.domains.country.models import Country
      from app.domains.holiday.models import Holiday
      target_metadata = Base.metadata
      
      # other values from the config, defined by the needs of env.py,
      
      • SQLAlchemy의 Base 클래스를 임포트해줘야 인식 가능


  • run_migrations_online( )

    • Alembic이 데이터베이스에 직접 연결하여 마이그레이션 스크립트를 실행
    • 주로 개발 단계나 실제 운영 서버에서 사용

    • alembic upgrade head 명령어를 실행하면 이 함수가 호출된다

      (우리가 사용할 것)



  • run_migrations_offline( )

    • Alembic이 데이터베이스에 직접 연결하지 않고, SQL 문을 문자열로 생성하여 출력하는 방식
    • 데이터베이스 접근이 불가능하거나 제한적인 환경(예: 배포 파이프라인)에서 사용
    • 이렇게 생성된 SQL 파일을 나중에 직접 데이터베이스에 적용할 수 있다
    • alembic upgrade head --sql 명령어를 실행하면 이 함수가 호출된다


  • settings.sync_database_url 사용하기

    # 수정 전
    def run_migrations_online() -> None:
        # ...
        connectable = engine_from_config(
            config.get_section(config.config_ini_section, {}),
            prefix='sqlalchemy.',
            poolclass=pool.NullPool,
        )
    
        with connectable.connect() as connection:
            context.configure(
                connection=connection, target_metadata=target_metadata
            )
    
            with context.begin_transaction():
                context.run_migrations()
    
    # 수정 후
    from app.core.config import settings
    
    def run_migrations_online() -> None:
        # ...
        connectable = engine_from_config(
            {'sqlalchemy.url': settings.sync_database_url},
            prefix="sqlalchemy.",
            poolclass=pool.NullPool,
        )
    
        with connectable.connect() as connection:
            context.configure(
                connection=connection, target_metadata=target_metadata
            )
    
            with context.begin_transaction():
                context.run_migrations()
    
    • 보안 강화를 위해서 alembic.ini에서 sqlalchemy.url을 읽어와서 데이터베이스에 연결하지 않고 Pydantic BaseSettings가 관리하는 settings.sync_database_url에서 직접 가져온다.
    • 보안 강화
      • sqlalchemy.url.ini 파일에 직접 노출하지 않고, .env 파일을 통해 환경 변수로 관리
    • 일관성 있게 환경 변수를 통해 관리한다


  • alembic.ini 파일 수정하기

    # database URL.  This is consumed by the user-maintained env.py script only.
    # other means of configuring database URLs may be customized within the env.py
    # file.
    # sqlalchemy.url = driver://user:pass@localhost/dbname
    
    • 이제 env.py에서 settings.sync_database_url직접 사용하도록 변경
      • Alembic은 이제 alembic.ini에 설정된 sqlalchemy.url 값을 사용하지 않는다
    • 주석 처리하거나 제거



마이그레이션 스크립트 적용하기

  • Alembic에게 마이그레이션 파일 요청

    alembic revision --autogenerate -m "Create Country Model table"
    
    • -autogenerate
      • env.pytarget_metadata를 보고 현재 데이터베이스 스키마와 비교한다
    • -m "..."
      • 마이그레이션에 대한 설명을 추가
    • 결과
      • alembic/versions/ 폴더에 테이블 생성 코드가 담긴 파이썬 파일이 생성


  • 마이그레이션 실행

    alembic upgrade head
    
    • upgrade head
      • 가장 최근에 생성된 마이그레이션 스크립트까지 모두 실행



스키마 변경 시 마이그레이션

  • 모델 변경 후 아래 과정을 적용해야 한다
  • 마이그레이션 스크립트 생성

    alembic revision --autogenerate -m "Add population to Country model"
    
    • --autogenerate로 DB의 스키마와 SQLAlchemy 모델을 비교하여 변경 분석
    • 결과로 alembic/versions/ 디렉토리에 새로운 버전 파일이 생성된다.


  • 마이그레이션 적용하기

    alembic upgrade head
    
    • 아직 데이터베이스에 적용되지 않은 마이그레이션 스크립트 중 가장 최신 버전(head)까지 모두 실행

app/domains/country/router.py (수정)

from fastapi import APIRouter
from fastapi.params import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List
from app.core.database import get_db
from app.core.deps import get_country_service
from app.common.schemas.base import BaseResponse
from app.domains.country.services import CountryService
from app.domains.country.schemas import CountryListItem, CountryResponse, CountryUpdate, CountryCreate

router = APIRouter(prefix="/countries")

@router.get("/", response_model=BaseResponse[List[CountryListItem]])
async def get_countries(
    country_service: CountryService = Depends(get_country_service),
    db: AsyncSession = Depends(get_db)
):
    countries = await country_service.get_all(db=db)
    response = [CountryListItem.model_validate(country) for country in countries]
    return {
        "code": "200",
        "message": "목록 조회 성공",
        "data": response
    }


@router.get("/{country_id}", response_model=BaseResponse[CountryResponse])
async def get_country(
    country_id: int,
    country_service: CountryService = Depends(get_country_service),
    db: AsyncSession = Depends(get_db)
):
    country = await country_service.get_by_id(db, country_id)
    resource = CountryResponse.model_validate(country)
    return {
        "code": "200",
        "message": f"{country_id}번 국가 조회 성공",
        "data": resource
    }


@router.post("/")
async def create_country(
    country_create_dto: CountryCreate,
    country_service: CountryService = Depends(get_country_service),
    db: AsyncSession = Depends(get_db)
):
    await country_service.create(db, obj_in=country_create_dto)
    return {
        "code": "201",
        "message": "새로운 국가 생성 성공",
        "data": None
    }


@router.put("/{country_id}")
async def update_country(
    country_id: int,
    country_update_dto: CountryUpdate,
    country_service: CountryService = Depends(get_country_service),
    db: AsyncSession = Depends(get_db)
):
    country = await country_service.get_by_id(db, country_id)
    await country_service.update(db, db_obj=country, obj_in=country_update_dto)
    return {
        "code": "204",
        "message": f"{country_id}번 국가 수정 성공",
        "data": None
    }


@router.delete("/{country_id}")
async def delete_country(
    country_id: int,
    country_service: CountryService = Depends(get_country_service),
    db: AsyncSession = Depends(get_db)
):
    country = await country_service.get_by_id(db, country_id)
    await country_service.remove(db, db_obj=country)
    return {
        "code": "204",
        "message": f"{country_id}번 국가 제거 성공",
        "data": None
    }
2.4. 공통 응답 및 에러 처리

app/common/schemas/base.py (수정)

from pydantic import BaseModel, Field
from typing import Generic, TypeVar, Optional
from datetime import datetime

T = TypeVar("T")

class BaseResponse(BaseModel, Generic[T]):
    code: str = Field("200", description="응답 코드")
    message: str = Field("OK", description="응답 메시지")
    data: Optional[T] = Field(None, description="응답 데이터")

class ErrorResponse(BaseModel):
    timestamp: datetime = Field(default_factory=datetime.now, description="에러 발생 일시")
    code: str = Field("500", description="에러 코드")
    message: str = Field("Internal Server Error", description="에러 메시지")
    data: None = Field(None, description="응답 데이터")
    details: Optional[Dict[str, Any]] = Field(None, description="에러 상세 정보")
  • BaseResponse

    T = TypeVar("T")
    
    class BaseResponse(BaseModel, Generic[T]):
        code: str = Field("200", description="응답 코드")
        message: str = Field("Success", description="응답 메시지")
        data: Optional[T] = Field(None, description="응답 데이터")
    
    • 정상적인 API 응답을 위한 기본 스키마
    • data

      • 제네릭 타입 T로 정의되어 있어, 다양한 데이터 모델을 담을 수 있다

    • Field( ) 함수 사용

      • Field() 함수로 메타데이터 추가 (Swagger API)
      • 조건을 설정하여 더 구체적으로 유효성 검사 가능
        • Pydantic BaseModel로 타입 검사는 기본적으로 수행한다
      • Field(...)로 기본값 없이 필수 필드로 지정


  • ErrorResponse

    class ErrorResponse(BaseModel):
        timestamp: datetime = Field(default_factory=datetime.now, description="에러 발생 일시")
        code: str = Field("500", description="에러 코드")
        message: str = Field("Internal Server Error", description="에러 메시지")
        data: None = Field(None, description="응답 데이터")
        details: Optional[Dict[str, Any]] = Field(None, description="에러 상세 정보")
    
    • 에러 발생 시 API 응답을 위한 스키마
    • data 필드는 항상 None
    • default_factory
      • Field 함수에서 사용
      • 인수가 없는 함수를 받아서 모델이 생성될 때마다 해당 함수를 호출하여 기본값 생성



app/common/schemas/pagination.py

from pydantic import BaseModel, Field
from typing import List, TypeVar, Generic

T = TypeVar("T")

class PaginationMeta(BaseModel):
    total_elements: int = Field(..., description="전체 항목 수")
    page: int = Field(..., description="현재 페이지 번호")
    page_size: int = Field(..., description="페이지 당 항목 수")
    has_next: bool = Field(..., description="다음 페이지 존재 여부")
    has_prev: bool = Field(..., description="이전 페이지 존재 여부")

class PageResponse(BaseModel, Generic[T]):
    data: List[T] = Field(..., description="응답 데이터")
    meta: PaginationMeta = Field(..., description="페이지 메타데이터")
  • PaginationMeta

    class PaginationMeta(BaseModel):
        total_elements: int = Field(..., description="전체 항목 수")
        page: int = Field(..., description="현재 페이지 번호")
        page_size: int = Field(..., description="페이지 당 항목 수")
        has_next: bool = Field(..., description="다음 페이지 존재 여부")
        has_prev: bool = Field(..., description="이전 페이지 존재 여부")
    
    • 페이지네이션 메타데이터를 위한 스키마
    • data
      • 제네릭 타입 T로 정의되어 있어, 다양한 데이터 모델을 담을 수 있다


  • PageResponse

    T = TypeVar("T")
    
    class PageResponse(BaseModel, Generic[T]):
        data: List[T] = Field(..., description="응답 데이터")
        meta: PaginationMeta = Field(..., description="페이지 메타데이터")
    
    • 페이지네이션 응답 스키마
    • data
      • 제네릭 타입 T로 정의되어 있어, 다양한 데이터 모델을 담을 수 있다



app/domains/country/schemas.py (수정)

from datetime import datetime

from pydantic import BaseModel, ConfigDict, Field
from typing import Optional


class CountryBase(BaseModel):
    country_code: str = Field(..., description="국가 코드", examples=["KR", "US"])

    model_config = ConfigDict(from_attributes=True)


class CountryCreate(CountryBase):
    name: str = Field(..., description="국가 영문명", examples=["Australia"])


class CountryUpdate(CountryCreate):
    name: Optional[str] = Field(None, description="국가 영문명", examples=["Australia"])
    country_code: Optional[str] = Field(None, description="국가 코드", examples=["KR", "US"])


class CountryListItem(CountryBase):
    id: int = Field(..., description="ID", examples=[4])


class CountryResponse(CountryCreate):
    id: int = Field(..., description="ID", examples=[4])
    created_at: datetime = Field(..., description="생성 일시", examples=[datetime.now()])
    updated_at: datetime = Field(..., description="수정 일시", examples=[datetime.now()])
  • Swagger UI 메타데이터 추가를 위해 Field() 함수를 활용

app/core/exceptions.py

from __future__ import annotations
from fastapi import status
from dataclasses import dataclass, field
from typing import Any, Dict
from enum import Enum

@dataclass(frozen=True)
class ErrorCode:
    code: str
    http_status: int
    message: str

class CustomException(Exception):
    def __init__(self, error_type: ErrorType, message: str = None, details: Dict[str, Any] = None):
        self.error_code = error_type.value
        self.details = details if details else {}
        self.message = message or self.error_code.message
        super().__init__(self.message)

class ErrorType(Enum):
    INTERNAL_SERVER_ERROR = ErrorCode(
        code="500", http_status=status.HTTP_500_INTERNAL_SERVER_ERROR, message="서버에 오류가 발생했습니다"
    )
    NOT_FOUND = ErrorCode(
        code="404", http_status=status.HTTP_404_NOT_FOUND, message="요청하신 리소스를 찾을 수 없습니다"
    )
    BAD_REQUEST = ErrorCode(
        code="400", http_status=status.HTTP_400_BAD_REQUEST, message="Bad Request"
    )
    VALIDATION_ERROR = ErrorCode(
        code="422", http_status=status.HTTP_422_UNPROCESSABLE_ENTITY, message="Validation Error"
    )

class ResourceNotFoundException(CustomException):
    def __init__(self, *, resource_type: str, resource_id: Any):
        message = f"ID {resource_id}{resource_type}을 찾을 수 없습니다."
        details = {"resource_id": resource_id, "resource_type": resource_type}
        super().__init__(ErrorType.NOT_FOUND, message=message, details=details)
  • from __future__ import annotations

    • Python 3.7+의 기능 : from __future__ import annotations
    • 타입 힌트의 순환 참조 문제 해결
    • 모든 타입 힌트를 문자열로 저장
      • 런타임 성능 향상
    • 지연 평가 (lazy evaluation)
      • 타입 힌트가 실제로 필요할 때만 평가
    • import 시간 단축
      • 타입 체킹용 모듈들을 즉시 로드하지 않는다


  • ErrorCode

    @dataclass(frozen=True)
    class ErrorCode:
        code: str
        http_status: int
        message: str
    
    • 표준화된 에러 코드를 정의하는 클래스
    • code : 클라이언트에 노출될 고유 에러 코드
    • http_status : HTTP 상태 코드
    • message : 기본 에러 메시지
    • @dataclass
      • Python 3.7부터 도입된 데코레이터
      • 데이터를 캡슐화하고 속성을 가진 객체를 쉽게 만들 때 사용
      • 데이터를 저장하는 클래스를 쉽게 만들 수 있도록 도와준다
        • __init__, __repr__, __eq__ 등 자동 생성
    • frozen=True
      • 불변(immutable) 객체를 만드는 기능
      • 생성 후 필드 값 변경 불가능


  • CustomException 클래스

    class CustomException(Exception):
        def __init__(self, error_type: ErrorType, message: str = None, details: Dict[str, Any] = None):
            self.error_code = error_type.value
            self.details = details if details else {}
            self.message = message or self.error_code.message
            super().__init__(self.message)
    
    class CustomException(Exception):
        def __init__(self, error_code: ErrorCode, message: str = None):
            super().__init__(message or error_code.message)
            self.error_code = error_code
            if message:
                self.error_code.message = message
    
    • 커스텀 에러 처리를 위한 예외 클래스
      • FastAPI가 제공하는 기본 HTTPException을 활용해도 되지만 상세한 설정이 가능

    • error_type Enum을 전달하여 예외 발생 시 상세 정보를 전달


    • CustomException 객체 생성 시 특정 메세지나 Detail을 포함한다면, 자체 필드로 그 값을 받는다

      def __init__(self, error_type: ErrorType, message: str = None, details: Dict[str, Any] = None):
              self.error_code = error_type.value
          self.details = details if details else {}
          self.message = message or self.error_code.message
      

    • details 예시
      • "details": {"field": "email", "reason": "invalid_format"}
      • "details": {"resource": "/users/123", "required_permission": "admin"}
      • "details": {"external_service": "payment_gateway", "error_code": "PG-503"}


  • ErrorType 클래스

    class ErrorType(Enum):
        INTERNAL_SERVER_ERROR = ErrorCode(
            code="500", http_status=status.HTTP_500_INTERNAL_SERVER_ERROR, message="서버에 오류가 발생했습니다"
        )
        NOT_FOUND = ErrorCode(
            code="404", http_status=status.HTTP_404_NOT_FOUND, message="요청하신 리소스를 찾을 수 없습니다"
        )
        BAD_REQUEST = ErrorCode(
            code="400", http_status=status.HTTP_400_BAD_REQUEST, message="Bad Request"
        )
        VALIDATION_ERROR = ErrorCode(
            code="422", http_status=status.HTTP_422_UNPROCESSABLE_ENTITY, message="Validation Error"
        )
    
    • Enum을 활용하여 실수하지 않도록 설정


  • ResourceNotFoundException 클래스

    class ResourceNotFoundException(CustomException):
        def __init__(self, *, resource_type: str, resource_id: Any):
            message = f"ID {resource_id}{resource_type}을 찾을 수 없습니다."
            details = {"resource_id": resource_id, "resource_type": resource_type}
            super().__init__(ErrorType.NOT_FOUND, message=message, details=details)
    
    • 쉽게 에러 처리를 할 수 있도록 에러 클래스를 미리 만들어 놓고 처리한다.



app/domains/country/exceptions.py

from app.core.exceptions import CustomException, ErrorType
from typing import Any

class CountryNotFoundException(CustomException):
    def __init__(self, *, resource_id: Any):
        message = f"ID {resource_id}번 국가를 찾을 수 없습니다."
        details = {"resource_id": resource_id, "resource_type": "Country"}
        super().__init__(ErrorType.NOT_FOUND, message=message, details=details)
  • Country 도메인 특화 에러 모음



app/core/exception_handlers.py

import logging
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from app.core.exceptions import CustomException, ErrorType
from app.common.schemas.base import ErrorResponse

logger = logging.getLogger(__name__)

def add_exception_handler(app: FastAPI):

    @app.exception_handler(CustomException)
    async def custom_exception_handler(request: Request, exc: CustomException):
        logger.error(f"CustomException: {exc.message} - Code: {exc.error_code.code}")

        response = ErrorResponse(
            code=exc.error_code.code,
            message=exc.message,
            details = exc.details if exc.details else None
        )

        return JSONResponse(
            status_code=exc.error_code.http_status,
            content=response.model_dump(mode="json")
        )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(request: Request, exc: RequestValidationError):
        logger.error(f"Validation error: {exc.errors()}")

        error_messages = []
        for error in exc.errors():
            field = " -> ".join(str(x) for x in error["loc"][1:])
            message = error["msg"]
            error_messages.append(f"{field}: {message}")

        error_code = ErrorType.VALIDATION_ERROR.value
        response = ErrorResponse(
            message="; ".join(error_messages),
            code=error_code.code
        )

        return JSONResponse(
            status_code=error_code.http_status,
            content=response.model_dump(mode="json")
        )

    @app.exception_handler(404)
    async def not_found_handler(request: Request, exc):

        error_code = ErrorType.NOT_FOUND.value
        response = ErrorResponse(
            message=error_code.message,
            code=error_code.code
        )

        return JSONResponse(
            status_code=error_code.http_status,
            content=response.model_dump(mode="json")
        )

    @app.exception_handler(500)
    async def internal_error_handler(request: Request, exc):
        logger.error(f"Internal server error: {str(exc)}")

        error_code = ErrorType.INTERNAL_SERVER_ERROR.value
        response = ErrorResponse(
            message=error_code.message,
            code=error_code.code
        )

        return JSONResponse(
            status_code=error_code.http_status,
            content=response.model_dump(mode="json")
        )
  • JSONResponse의 필요성

    • FastAPI기본적으로 Pydantic 모델이나 Python 딕셔너리를 반환하면 자동으로 JSON 응답으로 변환한다
    • 하지만 예외가 발생한 상황개발자가 직접 HTTP 상태 코드와 응답 바디를 명시적으로 제어해야 한다


  • model_dump( )

    • Pydantic V2에서 Pydantic 모델 객체를 파이썬 딕셔너리로 변환하는 메서드
    • 예외 발생 상황이므로 Pydantic 모델 객체를 HTTP 응답 본문으로 사용하기 위해 JSON 형식으로 변환 가능한 딕셔너리로 만들어야 한다
    • timestamp의 경우 직렬화 문제가 있을 수 있기 때문에 model_dump(mode=”json”)으로 설정


  • custom_exception_handler( )

    • CustomException 클래스의 자체 messagedetails 필드
      • 이를 통해 details 필드를 통해 동적으로 예외 상황의 상세 정보를 추가할 수 있다
      • 유효성 검사 실패 시
        • {"field": "email", "reason": "invalid_format"}
      • 외부 API 호출 실패 시 해당 서비스의 에러 코드
        • {"external_api": "payment_gateway", "error_code": "PG-400"}
      • 사용자를 위한 것이 아닌 오류 분석 시스템 구축 시 그것을 위한 메시지


  • validation_exception_handler( )

    • FastAPI가 요청을 받으면, Pydantic 모델의 규칙(ex : max_length, gt)에 따라 유효성 검사 수행

    • 유효성 검사에 실패 시 RequestValidationError라는 특별한 예외 객체가 자동으로 발생

      • 예외 객체에는 오류 리스트가 담겨 있다.

        [
        {
            "loc": ["body", "name"],
            "msg": "Input should be a valid string",
            "type": "string_type"
        },
        {
            "loc": ["body", "price"],
            "msg": "Input should be greater than 0",
            "type": "greater_than"
        }
        ]
        
    • validation_exception_handler 함수는 이 리스트를 순회

      • loc(오류 위치)와 msg(오류 메시지)를 뽑아내서 가공 후 클라이언트에 전달
      • name: Input should be a valid string; price: Input should be greater than 0



app/main.py

from app.core.exception_handlers import add_exception_handler
from app.core.exceptions import CustomException, ErrorType

add_exception_handler(app)

@app.get("/test/error")
async def test_error():
    raise CustomException(ErrorType.NOT_FOUND)
  • add_exception_handler(app)로 전역 에러 처리를 등록하면 전역 에러 처리가 가능해진다.


  • 예외 처리 테스트 API

    @app.get("/test/error")
    async def test_error():
        raise CustomException(ErrorType.NOT_FOUND)
    

app/common/crud/base.py (수정)

from typing import Type, TypeVar, Any, Optional, List
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase

ModelType = TypeVar("ModelType", bound=DeclarativeBase)

class CRUDBase:
    def __init__(self, model: Type[ModelType]):
        self.model = model

    async def get(self, db: AsyncSession, obj_id: Any) -> Optional[ModelType]:
        return await db.get(self.model, obj_id)

    async def get_multi(self, db: AsyncSession, *, skip: int = 0, limit: int = 100) -> List[ModelType]:
        stmt = select(self.model).offset(skip).limit(limit)
        result = await db.execute(stmt)
        return result.scalars().all()

    async def get_all(self, db: AsyncSession) -> List[ModelType]:
        stmt = select(self.model)
        results = await db.execute(stmt)
        return results.scalars().all()

    async def count(self, db: AsyncSession) -> int:
        stmt = select(func.count()).select_from(self.model)
        result = await db.execute(stmt)
        return result.scalar_one()

    async def create(self, db: AsyncSession, *, obj_in: Any) -> ModelType:
        db_obj = self.model(**obj_in.model_dump())
        db.add(db_obj)
        return db_obj

    async def update(self, db: AsyncSession, db_obj: ModelType, obj_in: Any) -> ModelType:
        update_data = obj_in.model_dump(exclude_unset=True)
        for key, value in update_data.items():
            setattr(db_obj, key, value)
        db.add(db_obj)
        return  db_obj

    async def remove(self, db: AsyncSession, *, db_obj: ModelType) -> ModelType:
        await db.delete(db_obj)
        return db_obj
  • 불필요한 메서드 제거
    • get_by_field, get_by_fields, get_multi_by_fields, count_by_fields




app/common/services/base.py (수정)

from typing import TypeVar, Any, Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import IntegrityError
from app.common.crud.base import CRUDBase, ModelType
from app.core.exceptions import ResourceNotFoundException, ResourceConflictException

CRUDType = TypeVar("CRUDType", bound=CRUDBase)

class BaseService:
    def __init__(self, crud: CRUDType):
        self.crud : CRUDType = crud

    async def get_by_id(self, db: AsyncSession, obj_id: Any) -> ModelType:
        obj = await self.crud.get(db, obj_id)
        if not obj:
            raise ResourceNotFoundException(
                resource_type=self.crud.model.__name__,
                resource_id=obj_id
            )
        return obj

    async def get_multi(self, db: AsyncSession, *, skip: int = 0, limit: int = 100) -> List[ModelType]:
        return await self.crud.get_multi(db, skip=skip, limit=limit)

    async def get_all(self, db: AsyncSession) -> List[ModelType]:
        return await self.crud.get_all(db)

    async def count(self, db: AsyncSession) -> int:
        return await self.crud.count(db)

    async def create(self, db: AsyncSession, *, obj_in: Any) -> ModelType:
        try:
            db_obj = await self.crud.create(db, obj_in=obj_in)
            await db.commit()
            await db.refresh(db_obj)
            return db_obj
        except IntegrityError:
            await db.rollback()
            raise ResourceConflictException(resource_type=self.crud.model.__name__, operation="생성")

    async def update(self, db: AsyncSession, db_obj: ModelType, obj_in: Any) -> ModelType:
        try:
            updated_obj = await self.crud.update(db, db_obj=db_obj, obj_in=obj_in)
            await db.commit()
            await db.refresh(updated_obj)
            return updated_obj
        except IntegrityError:
            await db.rollback()
            raise ResourceConflictException(resource_type=self.crud.model.__name__, operation="수정")

    async def remove(self, db: AsyncSession, *, db_obj: ModelType) -> ModelType:
        removed_obj = await self.crud.remove(db, db_obj=db_obj)
        await db.commit()
        return removed_obj
  • 제네릭 타입 오류 수정

    CRUDType = TypeVar("CRUDType", bound=CRUDBase)
    
    class BaseService:
        def __init__(self, crud: CRUDType):
            self.crud : CRUDType = crud
    
    • 이렇게 해야지 IDE가 정확한 타입을 인식할 수 있다


  • 불필요한 메서드 제거

    • get_by_field, get_by_fields, get_multi_by_fields, count_by_fields


  • get_by_id 함수

    • 수정 전

      async def get_by_id(self, db: AsyncSession, obj_id: Any) -> Optional[ModelType]:
          return await self.crud.get(db, obj_id)
      

    • 수정 후

      async def get_by_id(self, db: AsyncSession, obj_id: Any) -> ModelType:
          obj = await self.crud.get(db, obj_id)
          if not obj:
              raise ResourceNotFoundException(
                  resource_type=self.crud.model.__name__,
                  resource_id=obj_id
              )
          return obj
      
      • 반환 타입에서 Optional을 제거하여 객체를 반드시 반환하거나 예외를 발생시키도록 변경
      • ResourceNotFoundException 에러 처리
      • self.crud.model.__name__
        • CRUDBase 인스턴스의 model 속성으로 전달받은 모델 클래스의 클래스 이름을 문자열로 받는다




app/core/exceptions.py (수정)

class ErrorType(Enum):
    INTERNAL_SERVER_ERROR = ErrorCode(
        code="500", http_status=status.HTTP_500_INTERNAL_SERVER_ERROR, message="서버에 오류가 발생했습니다"
    )
    NOT_FOUND = ErrorCode(
        code="404", http_status=status.HTTP_404_NOT_FOUND, message="요청하신 리소스를 찾을 수 없습니다"
    )
    BAD_REQUEST = ErrorCode(
        code="400", http_status=status.HTTP_400_BAD_REQUEST, message="Bad Request"
    )
    VALIDATION_ERROR = ErrorCode(
        code="422", http_status=status.HTTP_422_UNPROCESSABLE_ENTITY, message="Validation Error"
    )
    CONFLICT = ErrorCode(
        code="409", http_status=status.HTTP_409_CONFLICT, message="리소스 충돌이 발생했습니다"
    )
    

class ResourceConflictException(CustomException):
    def __init__(self, *, resource_type: str, operation: str = "처리"):
        message = f"{resource_type} {operation} 중 중복 데이터가 발견되었습니다."
        details = {"resource_type": resource_type, "operation": operation}
        super().__init__(ErrorType.CONFLICT, message=message, details=details)
  • ErrorType에 CONFLICT 타입 추가

    CONFLICT = ErrorCode(
        code="409", http_status=status.HTTP_409_CONFLICT, message="리소스 충돌이 발생했습니다"
    )
    

  • ResourceConflictException 공용 예외 클래스 추가

    class ResourceConflictException(CustomException):
        def __init__(self, *, resource_type: str, operation: str = "처리"):
            message = f"{resource_type} {operation} 중 중복 데이터가 발견되었습니다."
            details = {"resource_type": resource_type, "operation": operation}
            super().__init__(ErrorType.CONFLICT, message=message, details=details)
    

app/domains/country/crud.py (수정)

from sqlalchemy.ext.asyncio import AsyncSession
from typing import Optional
from sqlalchemy import select
from app.domains.country.models import Country
from app.common.crud.base import CRUDBase

class CRUDCountry(CRUDBase):
    def __init__(self, model: type[Country]):
        super().__init__(model)

    async def get_by_country_code(self, db: AsyncSession, country_code: str) -> Optional[Country]:
        stmt = select(self.model).where(self.model.country_code == country_code)
        result = await db.execute(stmt)
        return result.scalars().one_or_none()

country_crud = CRUDCountry(Country)
  • get_by_country_code 함수 생성

    async def get_by_country_code(self, db: AsyncSession, country_code: str) -> Optional[Country]:
        stmt = select(self.model).where(self.model.country_code == country_code)
        result = await db.execute(stmt)
        return result.scalars().one_or_none()
    
    • Optional예외 처리가 가능하게 반환한다.




app/domains/country/exceptions.py (수정)

from app.core.exceptions import CustomException, ErrorType
from typing import Any

class CountryNotFoundException(CustomException):
    def __init__(self, *, field_name: str, value: Any):
        message = f"{field_name} '{value}'에 해당하는 국가를 찾을 수 없습니다."
        details = {"field_name": field_name, "value": value, "resource_type": "Country"}
        super().__init__(ErrorType.NOT_FOUND, message=message, details=details)

class CountryConflictException(CustomException):
    def __init__(self, *, operation: str = "처리"):
        message = f"Country {operation} 중 중복 데이터가 발견되었습니다."
        details = {"resource_type": "Country", "operation": operation}
        super().__init__(ErrorType.CONFLICT, message=message, details=details)
  • CountryConflictException 예외 클래스 추가

    class CountryConflictException(CustomException):
        def __init__(self, *, operation: str = "처리"):
            message = f"Country {operation} 중 중복 데이터가 발견되었습니다."
            details = {"resource_type": "Country", "operation": operation}
            super().__init__(ErrorType.CONFLICT, message=message, details=details)
    




app/domains/country/services.py (수정)

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import IntegrityError
from app.common.services.base import BaseService
from app.domains.country.crud import CRUDCountry
from app.domains.country.models import Country
from app.domains.country.schemas import CountryCreate, CountryUpdate
from app.domains.country.exceptions import CountryNotFoundException, CountryConflictException

class CountryService(BaseService):
    def __init__(self, crud: CRUDCountry):
        super().__init__(crud)

    async def get_by_country_code(self, db: AsyncSession, country_code: str) -> Country:
        country = await self.crud.get_by_country_code(db, country_code)
        if not country:
            raise CountryNotFoundException(field_name="국가 코드", value=country_code)
        return country

    async def create_country(self, db: AsyncSession, *, country_create_dto: CountryCreate) -> Country:
        try:
            country = await self.crud.create(db, obj_in=country_create_dto)
            await db.commit()
            await db.refresh(country)
            return country
        except IntegrityError:
            await db.rollback()
            raise CountryConflictException(operation="생성")

    async def update(self, db: AsyncSession, db_country: Country, country_update_dto: CountryUpdate) -> Country:
        try:
            updated_country = await self.crud.update(db, db_obj=db_country, obj_in=country_update_dto)
            await db.commit()
            await db.refresh(updated_country)
            return updated_country
        except IntegrityError:
            await db.rollback()
            raise CountryConflictException(operation="수정")

  • get_by_country_code 함수 생성

    async def get_by_country_code(self, db: AsyncSession, country_code: str) -> Country:
        country = await self.crud.get_by_country_code(db, country_code)
        if not country:
            raise CountryNotFoundException(resource_id=country_code)
        return country
    
    • Country 도메인 특화 에러인 CountryNotFoundException 에러 처리


  • create_country 함수 생성

    async def create_country(self, db: AsyncSession, *, country_create_dto: CountryCreate) -> Country:
        try:
            country = await self.crud.create(db, obj_in=country_create_dto)
            await db.commit()
            await db.refresh(country)
            return country
        except IntegrityError:
            await db.rollback()
            raise CountryConflictException(operation="생성")
    
    • Country 도메인 특화 에러인 CountryConflictException 에러 처리


  • update 함수 생성

    async def update(self, db: AsyncSession, db_country: Country, country_update_dto: CountryUpdate) -> Country:
        try:
            updated_country = await self.crud.update(db, db_obj=db_country, obj_in=country_update_dto)
            await db.commit()
            await db.refresh(updated_country)
            return updated_country
        except IntegrityError:
            await db.rollback()
            raise CountryConflictException(operation="수정")
    
    • Country 도메인 특화 에러인 CountryConflictException 에러 처리
    • 이 경우 공용 서비스의 update 함수를 오버라이딩한다.
      • 단, router에서 키워드 인수로 호출 시 인수 이름 주의




app/domains/country/router.py (수정)

@router.get("/code/{country_code}", response_model=BaseResponse[CountryResponse])
async def get_country_by_country_code(
    country_code: str,
    country_service: CountryService = Depends(get_country_service),
    db: AsyncSession = Depends(get_db)
):
    country = await country_service.get_by_country_code(db, country_code)
    resource = CountryResponse.model_validate(country)
    return {
        "code": "200",
        "message": f"국가 코드 {country_code}의 국가 조회 성공",
        "data": resource
    }
    

@router.post("/")
async def create_country(
    country_create_dto: CountryCreate,
    country_service: CountryService = Depends(get_country_service),
    db: AsyncSession = Depends(get_db)
):
    await country_service.create(db, obj_in=country_create_dto)
    return {
        "code": "201",
        "message": "새로운 국가 생성 성공",
        "data": None
    }

@router.put("/{country_id}")
async def update_country(
    country_id: int,
    country_update_dto: CountryUpdate,
    country_service: CountryService = Depends(get_country_service),
    db: AsyncSession = Depends(get_db)
):
    country = await country_service.get_by_id(db, country_id)
    await country_service.update(db, country, country_update_dto)
    return {
        "code": "204",
        "message": f"{country_id}번 국가 수정 성공",
        "data": None
    }
  • 테스트 용 API 추가
  • update 함수 ⇒ Country 서비스 특화 update() 호출
  • create 함수 ⇒ 공용 서비스 create() 호출 (작동 테스트)




app/domains/country/models.py (수정)

from sqlalchemy import Column, Integer, String, DateTime
from sqlalchemy.orm import relationship
from app.common.models.base import BaseModel

class Country(BaseModel):
    __tablename__ = 'countries'

    country_code = Column(String(10), unique=True, nullable=False)
    name = Column(String(100), nullable=False)

    holidays = relationship(
        "Holiday",
        back_populates="country",
        cascade="all, delete-orphan",
    )
  • cascade 추가
    • 외래 키 제약으로 인해 삭제가 실패 대처
2.5. Celery 스케줄링 및 외부 API 호출 구현하기

app/external/api_client.py

import httpx

from app.core.config import settings
from app.domains.holiday.schemas import (
    AvailableCountriesApiResponse,
    PublicHolidaysApiResponse,
)

class HolidayAPIClient:
    def __init__(self):
        self.holidays_url = settings.nager_api_holidays_url
        self.countries_url = settings.nager_api_countries_url

    async def fetch_countries_data(self) -> list[AvailableCountriesApiResponse]:
        async with httpx.AsyncClient() as client:
            response = await client.get(self.countries_url)
            response.raise_for_status()
            return [AvailableCountriesApiResponse(**item) for item in response.json()]

    async def fetch_holidays_data(
        self, year: int, country_code: str
    ) -> list[PublicHolidaysApiResponse]:
        url = f"{self.holidays_url}/{year}/{country_code}"
        async with httpx.AsyncClient() as client:
            response = await client.get(url)
            response.raise_for_status()
            return [PublicHolidaysApiResponse(**item) for item in response.json()]

holiday_api_client = HolidayAPIClient()
  • httpx.AsyncClient( )

    async with httpx.AsyncClient() as client:
        response = await client.get(self.countries_url)
        response.raise_for_status()
        return [AvailableCountriesApiResponse(**item) for item in response.json()]
    
    • httpx

      • Python의 HTTP 클라이언트 라이브러리

    • httpx.AsyncClient()

      • 비동기 환경에서 HTTP 요청을 보낼 때 사용되는 객체
      • 비동기 HTTP 요청을 보낼 수 있는 클라이언트 세션을 생성
      • 이 함수는 async with과 함께 사용될 때 가장 효율적이다.
        • 시작될 때 네트워크 연결이 설정되고, 블록이 끝날 때 연결이 자동으로 해제되어 자원 누수 방지

    • AsyncClient().get(url)

      • AsyncClient 인스턴스의 메서드로, HTTP GET 요청 보내기
      • 요청이 완료될 때까지 실행을 일시 중단했다가 응답을 받으면 계속 진행
      • 다양한 인자
        • 인수로 **url**을 넣어서 GET 요청을 보낸다.
        • params
          • URL 쿼리 파라미터를 딕셔너리 형태로 전달
        • headers
          • 요청 헤더를 딕셔너리 형태로 전달
        • timeout
          • 요청 타임아웃 시간을 설정.
        • follow_redirects
          • 리다이렉션을 자동으로 따를지 여부를 설정

    • httpx.Response

      • 모든 client.method() 함수는 httpx.Response 객체를 반환
      • 이 객체에는 요청의 결과에 대한 모든 정보가 담겨 있다.
      • response.status_code
        • HTTP 상태 코드 (예 : 200)
      • response.text
        • 응답 본문을 문자열로 반환
      • response.json()
        • 응답 본문을 JSON으로 파싱하여 파이썬 딕셔너리나 리스트로 반환
      • response.headers
        • 응답 헤더를 딕셔너리로 반환


  • AsyncClient 인스턴스의 다른 함수들

    await client.post(url, data)
    await client.put(url, data)
    await client.delete(url)
    await client.patch(url, data)
    await client.head(url)
    


  • raise_for_status 함수

    async with httpx.AsyncClient() as client:
        response = await client.get(self.countries_url)
        response.raise_for_status()
        return [AvailableCountriesApiResponse(**item) for item in response.json()]
    
    • HTTP 요청의 상태 코드가 성공적이지 않을 때 예외를 발생시키는 편리한 메서드
    • 동작 방식

      • 응답 상태 코드(response.status_code)가 특정 에러인 경우 예외 발생
        • 4xx 번대 에러 또는 5xx 번대 에러일 경우
        • httpx.HTTPStatusError 예외 발생
      • 만약 상태 코드가 2xx이거나 3xx이라면 그냥 넘어간다

    • 목적

      • 명확한 에러 핸들링
      • 코드 간결화
        • 수동적인 상태 코드 체크 로직을 한 줄로 대체 가능

    • 이 자리에 올 수 있는 다른 함수

      • if not response.is_success
        • response 객체의 is_success 속성을 직접 확인하는 방법
      • try-except 블록
        • raise_for_status()는 예외를 발생시키므로, 이를 try-except 블록으로 감싸 특정 에러에 대해 분기 가능
        • 예시 : httpx.HTTPStatusError를 잡아서 로그 삽입



app/domains/holiday/schemas.py (수정)

...

class PublicHolidaysApiResponse(BaseModel):
    date: Date
    local_name: str = Field(..., alias="localName")
    name: str
    country_code: str = Field(..., alias="countryCode")
    is_global: bool = Field(..., alias="global")
    counties: Optional[List[str]] = None
    launch_year: Optional[int] = Field(None, alias="launchYear")
    types: Optional[List[str]] = None

class AvailableCountriesApiResponse(BaseModel):
    country_code: str = Field(..., alias="countryCode")
    name: str

class SyncDataRequest(BaseModel):
    country_code: str = Field(..., description="동기화할 국가 코드", examples=["KR"])
    year: int = Field(..., description="동기화할 연도", examples=[2024])
  • Pydantic 모델 필드 명명
    • 기본적으로 Pydantic 모델의 필드명은 JSON의 키와 일치해야 한다

    • 부득이한 경우, **Field(alias="…")**으로 맞춰 준다

    • 딕셔너리(JSON 데이터)를 언패킹을 통해 키워드 인자로 전환하여 Pydantic 모델로 값을 전달한다.

      [AvailableCountriesApiResponse(**item) for item in response.json()]
      
      1. response.json()
        • 외부 API의 JSON 응답을 파이썬의 List로 변환
        • 이 리스트의 각 요소는 딕셔너리이다
      2. for item in ...
        • List 순회
      3. AvailableCountriesApiResponse(**item)
        • 각 딕셔너리의 키-값 쌍을 AvailableCountriesApiResponse 클래스의 생성자로 전달 (언패킹)
        • 이 때 딕셔너리의 키(countryCode, name)가 Pydantic 필드 이름과 일치해야 값이 정상적으로 할당된다.
      4. Pydantic은 이 인자들을 받아서 타입 유효성 검사를 수행
        • 유효한 값들을 기반으로 새로운 Pydantic 모델 인스턴스를 생성




app/domains/holiday/services.py (수정)

# ...
from app.domains.country.crud import country_crud
from app.domains.holiday.schemas import PublicHolidaysApiResponse
from app.external.api_client import holiday_api_client
from app.core.exceptions import SyncException

class HolidayService(BaseService):
    def __init__(self, crud: CRUDHoliday, country_crud: CRUDCountry):
        super().__init__(crud)
        self.crud = crud
        self.country_crud = country_crud

    ....

    async def sync_holidays_by_country(
        self, db: AsyncSession, *, country_code: str, years: List[int]
    ):
        try:
            async with db.begin():
                total_sync_count = 0
                for year in years:
                    total_sync_count += await self._sync_holidays(
                        db, country_code=country_code, year=year
                    )
                return total_sync_count

        except Exception as e:
            logger.warning(f"🔴 {country_code}의 공휴일 동기화 중 오류 발생: {str(e)}")
            raise HolidaySyncException(
                country_code=country_code, year=year, error_message="진행 중 오류 발생"
            )

    async def sync_holidays_from_api(
        self, db: AsyncSession, *, country_code: str, year: int
    ):
        try:
            async with db.begin():
                return await self._sync_holidays(
                    db, country_code=country_code, year=year
                )

        except Exception as e:
            logger.warning(f"🔴 {country_code}{year}년 공휴일 동기화 중 오류 발생: {str(e)}")
            raise HolidaySyncException(
                country_code=country_code, year=year, error_message="진행 중 오류 발생"
            )

    async def _sync_holidays(
        self, db: AsyncSession, *, country_code: str, year: int
    ) -> int:
        country = await self.country_crud.get_by_country_code(db, country_code)
        if not country:
            raise CountryNotFoundException(field_name="국가 코드", value=country_code)

        api_data: List[
            PublicHolidaysApiResponse
        ] = await holiday_api_client.fetch_holidays_data(year, country_code)
        if not api_data:
            raise HolidaySyncException(
                country_code=country_code, year=year, error_message="공휴일이 없습니다"
            )

        api_data_set = set()
        unique_data = []
        for item in api_data:
            unique_key = (item.date, item.name, country.id)
            if unique_key not in api_data_set:
                api_data_set.add(unique_key)
                unique_data.append(item)

        holidays = [
            Holiday(
                date=item.date,
                local_name=item.local_name,
                name=item.name,
                country_id=country.id,
                is_global=item.is_global,
                counties=",".join(item.counties) if item.counties is not None else None,
                launch_year=item.launch_year,
                types=",".join(item.types) if item.types is not None else None,
                holiday_year=item.date.year,
            )
            for item in unique_data
        ]
        await self.crud.remove_all_by(db, country_id=country.id, year=year)
        await self.crud.create_all(db, holidays)
        return len(holidays)

  • 공휴일 외부 API 호출 로직 구현

    • api_client 파일holiday_api_client 인스턴스를 이용하여 공휴일 데이터 가져오기
      • JSON 파일을 딕셔너리 리스트로 받아오기
      • 딕셔너리 리스트를 DTO 리스트로 전환
    • DTO 리스트를 Holiday 모델 리스트로 전환
    • 해당 범위의 DB 데이터 싹 지우기
    • 새로운 Holiday 리스트 데이터 DB에 저장


  1. 국가코드로 국가 조회하기

    async def _sync_holidays(
        self, db: AsyncSession, *, country_code: str, year: int
    ) -> int:
        country = await self.country_crud.get_by_country_code(db, country_code)
        if not country:
            raise CountryNotFoundException(field_name="국가 코드", value=country_code)
    
    • 이용 가능한 국가를 미리 조회 저장했으므로, 해당 국가가 없다면 예외 발생


  2. API 클라이언트를 사용하여 외부 데이터 가져오기

    api_data: List[
        PublicHolidaysApiResponse
    ] = await holiday_api_client.fetch_holidays_data(year, country_code)
    if not api_data:
        raise HolidaySyncException(
            country_code=country_code, year=year, error_message="공휴일이 없습니다"
        )
    
    • holiday_api_clientfetch_holidays_data() 사용


  3. DTO를 Holiday 모델로 전환

    api_data_set = set()
    unique_data = []
    for item in api_data:
        unique_key = (item.date, item.name, country.id)
        if unique_key not in api_data_set:
            api_data_set.add(unique_key)
            unique_data.append(item)
    
    holidays = [
        Holiday(
            date=item.date,
            local_name=item.local_name,
            name=item.name,
            country_id=country.id,
            is_global=item.is_global,
            counties=",".join(item.counties) if item.counties is not None else None,
            launch_year=item.launch_year,
            types=",".join(item.types) if item.types is not None else None,
            holiday_year=item.date.year,
        )
        for item in unique_data
    ]
    
    • unique_data 리스트는 외부 API 자체에 중복되는 데이터가 있는 경우 처리하기 위함이다.
    • counties, types의 경우 DB에 CSV 형태의 문자열로 저장되도록 처음에 설정했으므로 그에 맞게 변형이 필요하다.
    • holiday_year가 빠져 None으로 들어가면 바로 에러가 발생하니 주의하자


  4. DB에 데이터를 저장

    await self.crud.remove_all_by(db, country_id=country.id, year=year)
    await self.crud.create_all(db, holidays)
    return len(holidays)
    
    • 먼저 해당 연도, 국가의 기존 데이터를 삭제하고 새로운 데이터를 저장
    • 트랜잭션 관리는 상위 함수에서 관리




app/domains/holiday/router.py (수정)

# ... (기존 임포트)
from fastapi import APIRouter, Query, BackgroundTasks

router = APIRouter(prefix="/holidays", tags=["holidays"])

# ... 

@router.post(
    "/sync",
    response_model=BaseResponse,
    description="특정 연도 및 국가 데이터 동기화 (Refresh)",
)
async def sync_holidays(
    background_task: BackgroundTasks,
    sync_request: SyncDataRequest,
    holiday_service: HolidayService = Depends(get_holiday_service),
    db: AsyncSession = Depends(get_db),
):
    background_task.add_task(
        holiday_service.sync_holidays_from_api,
        db=db,
        year=sync_request.year,
        country_code=sync_request.country_code,
    )
    return BaseResponse(code="204", message="공휴일 동기화 작업 백그라운드 시작", data=None)
  • 테스트용 API 엔드포인트 추가

    • Spring 스케줄러처럼 API 요청에 직접 로직을 실행하는 대신, 백그라운드 태스크로 넘겨서 웹 서버의 응답성을 유지한다
    • 클라이언트는 즉시 응답을 받고, 서버는 요청을 처리하는 메인 프로세스를 블로킹하지 않고 작업을 이어나갈 수 있다.


  • BackgroundTasks

    • FastAPI가 제공하는 클래스
    • FastAPI 엔드포인트 함수에 의존성으로 주입된다
    • HTTP 요청이 완료된 후 백그라운드에서 실행될 작업을 등록하는 기능


  • add_task( )

    • 백그라운드에서 실행할 함수를 등록
    • 첫 번째 인자는 호출할 함수
    • 이후의 모든 인자는 그 함수의 인수로 전달된다
    • 작업 실행 시점
      • add_task()로 등록된 모든 작업은 라우터 함수가 완료되어 클라이언트로 HTTP 응답이 반환된 직후에 실행된다
      • “나중에 실행해 줘"라고 등록만 해두는 것




app/domains/country/services.py (수정)

# ...
from app.domains.holiday.schemas import AvailableCountriesApiResponse
from app.external.api_client import holiday_api_client
from app.core.exceptions import SyncException

class CountryService(BaseService):
    def __init__(self, crud: CRUDCountry):
        super().__init__(crud)
        self.crud = crud

    ....

    async def sync_countries_from_api(self, db: AsyncSession) -> Dict[str, int]:
        try:
            async with db.begin():
                api_data: List[
                    AvailableCountriesApiResponse
                ] = await holiday_api_client.fetch_countries_data()
                if not api_data:
                    raise SyncException(
                        resource_type="Country", error_message="이용 가능한 국가가 없습니다"
                    )

                api_country_codes = [item.country_code for item in api_data]
                existing_countries = await self.crud.get_existing_all(
                    db, api_country_codes
                )
                existing_countries_dict = {
                    country.country_code: country for country in existing_countries
                }

                new_countries = []
                updated_count = 0
                for item in api_data:
                    if item.country_code in existing_countries_dict:
                        existing_country = existing_countries_dict[item.country_code]
                        if existing_country.name != item.name:
                            logger.warning(
                                f"🟢 국가 업데이트 : {item.country_code} - {existing_country.name} -> {item.name}"
                            )
                            existing_country.name = item.name
                            updated_count += 1
                    else:
                        new_countries.append(
                            Country(
                                country_code=item.country_code,
                                name=item.name,
                            )
                        )

                if new_countries:
                    await self.crud.create_all(db, new_countries)

                return {
                    "available_countries": len(api_data),
                    "new_countries": len(new_countries),
                    "existing_countries": len(existing_countries),
                }

        except Exception as e:
            logger.warning(f"🔴 국가 동기화 중 오류 발생: {str(e)}")
            raise SyncException(resource_type="Country", error_message="진행 중 오류 발생")

    async def get_countries_to_sync(
        self, db: AsyncSession, threshold_hours: int = 24
    ) -> List[Country]:
        threshold_time = datetime.now() - timedelta(hours=threshold_hours)
        return await self.crud.get_all_by_sync_time(db, threshold_time)

    async def update_sync_time_bulk(
        self, db: AsyncSession, db_countries: List[Country]
    ):
        try:
            async with db.begin():
                await self.crud.update_sync_time_bulk(db, db_countries=db_countries)

        except Exception as e:
            raise SyncException(
                resource_type="Country", error_message="sync_time 업데이트 오류"
            )
  • 국가 정보 외부 API 호출 로직 구현

    • api_client 파일holiday_api_client 인스턴스를 이용하여 국가 데이터 가져오기
      • JSON 파일을 딕셔너리 리스트로 받아오기
      • 딕셔너리 리스트를 DTO 리스트로 전환
    • DTO 리스트를 Country 모델 리스트로 전환
    • 가져온 목록과 기존 목록을 비교 작업
      • 수정이 필요한 경우만 수정
      • 신규 데이터만 따로 추출
    • 새로운 Country 리스트 데이터 DB에 저장


  1. API 클라이언트를 사용하여 외부 데이터 가져오기

    async with db.begin():
        api_data: List[
            AvailableCountriesApiResponse
        ] = await holiday_api_client.fetch_countries_data()
        if not api_data:
            raise SyncException(
                resource_type="Country", error_message="이용 가능한 국가가 없습니다"
            )
    
    • holiday_api_clientfetch_countries_data() 사용


  2. 기존 데이터와 신규 데이터 분리

    api_country_codes = [item.country_code for item in api_data]
    existing_countries = await self.crud.get_existing_all(
        db, api_country_codes
    )
    existing_countries_dict = {
        country.country_code: country for country in existing_countries
    }
    
    new_countries = []
    updated_count = 0
    for item in api_data:
        if item.country_code in existing_countries_dict:
            existing_country = existing_countries_dict[item.country_code]
            if existing_country.name != item.name:
                logger.warning(
                    f"🟢 국가 업데이트 : {item.country_code} - {existing_country.name} -> {item.name}"
                )
                existing_country.name = item.name
                updated_count += 1
        else:
            new_countries.append(
                Country(
                    country_code=item.country_code,
                    name=item.name,
                )
            )
    
    • 기존 데이터의 경우, 수정이 필요한 경우만 수정
    • 신규 데이터만 DTO를 Country 모델로 전환


  3. DB에 데이터를 저장

    if new_countries:
        await self.crud.create_all(db, new_countries)
    
    return {
        "available_countries": len(api_data),
        "new_countries": len(new_countries),
        "existing_countries": len(existing_countries),
    }
    
    • 새롭게 추출한 신규 데이터만 DB 삽입
    • 트랜잭션 관리는 서비스 Layer에서 진행




app/domains/holiday/router.py (수정)

@router.post(
    "/sync",
    response_model=BaseResponse,
    description="국가 데이터 동기화",
)
async def sync_holidays(
    background_task: BackgroundTasks,
    country_service: CountryService = Depends(get_country_service),
    db: AsyncSession = Depends(get_db),
):
    background_task.add_task(
        country_service.sync_countries_from_api,
        db=db,
    )
    return BaseResponse(code="204", message="국가 동기화 작업 백그라운드 시작", data=None)

app/core/celery_app.py

from celery import Celery
from celery.schedules import crontab

import app.tasks.holiday_tasks  # noqa: F401
from app.core.config import settings

celery_app = Celery("holiday_keeper")

celery_app.conf.broker_url = settings.celery_broker_url
celery_app.conf.result_backend = settings.celery_result_backend

celery_app.conf.timezone = settings.celery_timezone
celery_app.conf.enable_utc = settings.celery_enable_utc

celery_app.autodiscover_tasks(["app.tasks"])

beat_schedule = {
    "annual-sync": {
        "task": "app.tasks.holiday_tasks.sync_holidays",
        "schedule": crontab(hour=2, minute=0, day_of_month=2, month_of_year=1),
    }
}

if settings.celery_test_sync_enabled:
    beat_schedule["test-sync"] = {
        "task": "app.tasks.holiday_tasks.sync_holidays",
        "schedule": crontab(minute="5"),
    }

celery_app.conf.beat_schedule = beat_schedule

  • 환경 변수로 Celery 설정

    celery_app.conf.broker_url = settings.celery_broker_url
    celery_app.conf.result_backend = settings.celery_result_backend
    
    celery_app.conf.timezone = settings.celery_timezone
    celery_app.conf.enable_utc = settings.celery_enable_utc
    


  • 환경 변수로 Test Task 관리

    beat_schedule = {
        "annual-sync": {
            "task": "app.tasks.holiday_tasks.sync_holidays",
            "schedule": crontab(hour=2, minute=0, day_of_month=2, month_of_year=1),
        }
    }
    
    if settings.celery_test_sync_enabled:
        beat_schedule["test-sync"] = {
            "task": "app.tasks.holiday_tasks.sync_holidays",
            "schedule": crontab(minute="5"),
        }
    
    celery_app.conf.beat_schedule = beat_schedule
    
    • 환경 변수 설정에 따라 Task 삽입 여부를 결정한다.




app/tasks/holiday_tasks.py

from __future__ import annotations

import asyncio
import logging
from contextlib import asynccontextmanager
from datetime import datetime
from typing import List

from celery import shared_task

from app.core.database import AsyncSessionLocal, async_engine
from app.core.deps import get_country_service, get_holiday_service
from app.domains.country.models import Country

logger = logging.getLogger(__name__)

@asynccontextmanager
async def get_async_app_context():
    try:
        yield
    finally:
        await async_engine.dispose()

@shared_task(bind=True, name="app.tasks.holiday_tasks.init_holidays")
def init_holidays(self):
    logger.info("🟢 Celery Task - init_holidays 시작")

    async def _run():
        async with get_async_app_context():
            current_year = datetime.now().year
            years = [current_year - i for i in range(5)]
            await run_sync_task(years)
            logger.info(f"🟢 Celery Task - sync_holidays 완료 (연도: {years})")

    try:
        asyncio.run(_run())
    except Exception as e:
        logger.error(f"🔴 Celery Task - init_holidays 실행 중 오류 발생 : {e}")
        raise self.retry(exc=e, countdown=60, max_retries=3)

@shared_task(bind=True, name="app.tasks.holiday_tasks.sync_holidays")
def sync_holidays(self):
    logger.info("🟢 Celery Task - sync_holidays 시작")

    async def _run():
        async with get_async_app_context():
            current_year = datetime.now().year
            years = [current_year, current_year - 1]
            await run_sync_task(years)
            logger.info(f"🟢 Celery Task - sync_holidays 완료 (연도: {years})")

    try:
        asyncio.run(_run())
    except Exception as e:
        logger.error(f"🔴 Celery Task - sync_holidays 실행 중 오류 발생 : {e}")
        raise self.retry(exc=e, countdown=60, max_retries=3)

async def run_sync_task(years: List[int]):
    try:
        async with AsyncSessionLocal() as db:
            country_service = get_country_service()
            holiday_service = get_holiday_service()

            logger.info("🟢 국가 데이터 동기화 시작")
            result_dict = await country_service.sync_countries_from_api(db)
            logger.info(f"🟢 {result_dict['available_countries']}개 국가 데이터 동기화 완료")
            logger.warning(
                f"🟢 신규 : {result_dict['new_countries']} / 기존 : {result_dict['existing_countries']}"
            )

            async with db.begin():
                countries_to_sync = await country_service.get_countries_to_sync(
                    db, threshold_hours=24
                )
                logger.warning(f"🟢 동기화 대상 국가: {len(countries_to_sync)}개")

            error_country_codes: List[str] = []
            success_country_list: List[Country] = []
            total_sync_count = 0
            for country in countries_to_sync:
                logger.info(f"🟢 국가 코드 {country.country_code} : 공휴일 데이터 동기화 시작")
                try:
                    total_sync_count += await holiday_service.sync_holidays_by_country(
                        db, country_code=country.country_code, years=years
                    )
                    success_country_list.append(country)

                except Exception as e:
                    error_country_codes.append(country.country_code)
                    logger.warning(
                        f"🔴 국가 코드 {country.country_code} : 공휴일 동기화 중 오류 발생: {e}"
                    )
                    continue

            await country_service.update_sync_time_bulk(db, success_country_list)

            logger.warning(f"🟢 {years}년 공휴일 {total_sync_count}개 데이터 동기화 완료")
            if error_country_codes:
                logger.warning(f"🔴 실패 국가 목록 {error_country_codes}")
            else:
                logger.info(f"🟢 실패 국가 없음")

    except Exception as e:
        logger.error(f"🔴 동기화 작업 중 에러 발생 : {e}")
        raise





app/core/lifespan.py

import logging
from contextlib import asynccontextmanager

from fastapi import FastAPI

from app.core.config import settings

logger = logging.getLogger("uvicorn")

@asynccontextmanager
async def lifespan(app: FastAPI):
    logger.warning("🟢 FastAPI 애플리케이션 시작")

    try:
        # 초기 데이터 로딩
        if settings.is_prod:
            logger.info("🟢 초기 데이터 적재 시작")
            await trigger_init_holidays()
        else:
            logger.info("🟢 개발환경 : 초기 데이터 로딩 Skip")

    except Exception as e:
        logger.error(f"🔴 초기화 작업 중 오류 발생: {e}")
        raise

    yield

    logger.warning("🟢 FastAPI 애플리케이션 종료 ")

async def trigger_init_holidays():
    try:
        from app.core.celery_app import celery_app

        task = celery_app.send_task("app.tasks.holiday_tasks.init_holidays")
        logger.info(f"🟢 초기 데이터 로딩 작업 시작됨 (Task ID : {task.id})")

    except Exception as e:
        logger.error(f"🔴 초기 데이터 로딩 실패: {e}")

  • Celery 활용 초기 데이터 적재하기




app/main.py

app = FastAPI(
    title="Holiday Keeper API",
    description="공휴일 정보 관리 API",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan,
)
  • uvicorn 실행 시 --lifespan on 설정

*uvicorn* app.main:app --reload --host $SERVER_HOST --port $SERVER_PORT --log-level debug --lifespan on