- Published on
- •👁️
FastAPI 마이그레이션 - 2. FastAPI 앱 구현 및 배포
- Authors

- Name
- River
이전 페이지로 이동 (1. 프로젝트 기반 구축)
FastAPI 마이그레이션
(Spring Boot에서 FastAPI로 전환)
2.1. 최소 FastAPI 앱 생성
app/main.py
from fastapi import FastAPI
app = FastAPI(
title="Holiday Keeper API",
description="공휴일 정보 관리 API",
version="1.0.0",
docs_url="/docs",
redoc_url="/redoc"
)
@app.get("/")
async def root():
return {"message": "Holiday Keeper API is running!"}
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"service": "holiday-keeper-api",
"version": "1.0.0"
}
@app.get("/test")
async def test():
return {
"message": "테스트 성공",
"data": ["공휴일1", "공휴일2", "공휴일3"]
}
서버 실행 테스트
# 스크립트 사용 ./scripts/dev.sh # 직접 서버 실행 poetry run uvicorn app.main:app --reload --host 127.0.0.1 --port 8090- 확인
- http://localhost:8090 접속
- http://localhost:8090/docs 접속 (Swagger)
- 확인
2.2. 환경설정 및 기본 구조
Spring Boot (자동 구성)
application.yml
server: port: 8090 spring: datasource: url: jdbc:postgresql://localhost:5432/holiday_db- 설정 값 저장 (DB 연결 정보, 포트 등)
- Spring Boot의 핵심 중 하나인 Auto Configuration으로 자동 인식
WebConfig.java
@Configuration public class WebConfig implements WebMvcConfigurer { @Override public void addCorsMappings(CorsRegistry registry) { registry.addMapping("/**") .allowedOrigins("http://localhost:3000") .allowedMethods("*"); } @Override public void addInterceptors(InterceptorRegistry registry) { registry.addInterceptor(new LoggingInterceptor()); } }- 설정 값 적용, CORS 직접 설정, 인터셉터 등록
@Configuration으로 자동 스캔
LoggingInterceptor.java
@Component public class LoggingInterceptor implements HandlerInterceptor { @Override public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) { log.info("Request: {} {}", request.getMethod(), request.getRequestURI()); return true; } }- 실제 로직 구현
FastAPI (수동 제어)
.env
SERVER_HOST="0.0.0.0" SERVER_PORT=8090 DATABASE_URL="postgresql://localhost:5432/holiday_db" CORS_ORIGINS="http://localhost:3000,http://localhost:5173"- 설정 값 저장 (DB 연결 정보, 포트 등)
config.py
from pydantic_settings import BaseSettings class Settings(BaseSettings): server_host: str = "0.0.0.0" server_port: int = 8090 database_url: str = "postgresql://..." cors_origins: str = "http://localhost:3000" model_config = {"env_file": ".env"} @property def cors_origins_list(self) -> List[str]: return self.cors_origins.split(",") settings = Settings()- .env 파일 읽기
- 설정 객체 생성
main.py
from fastapi import FastAPI from app.core.config import settings from app.core.middleware import add_middleware app = FastAPI(title="Holiday Keeper API") add_middleware(app) @app.get("/") async def root(): return {"message": "Hello"}- 수동으로 모든 설정 연결
middleware.py
from fastapi.middleware.cors import CORSMiddleware from app.core.config import settings def add_middleware(app: FastAPI): # CORS 설정 (WebConfig 역할) app.add_middleware( CORSMiddleware, allow_origins=settings.cors_origins_list, allow_credentials=True, allow_methods=["*"] ) # 로깅 미들웨어 (Interceptor 역할) @app.middleware("http") async def log_requests(request, call_next): start_time = time.time() response = await call_next(request) process_time = time.time() - start_time logger.info(f"{request.method} {request.url} - {process_time:.3f}s") return response- Spring Boot의 WebConfig + Interceptor 구현을 한 파일에 합친 형태
deps.py
from fastapi import Depends from app.core.config import settings async def get_db(): # DB 세션 의존성 주입용 pass def get_current_user(): # 인증 의존성 주입용 pass- Spring Boot의 @Autowired 의존성 주입을 수동으로 구현하는 곳
app/core/config.py
from functools import lru_cache
from typing import List
import os
from pathlib import Path
from pydantic import Field
from pydantic_settings import BaseSettings
environment = os.getenv("ENVIRONMENT", "dev")
env_file = f".env.{environment}"
_secrets_dir = Path("/run/secrets") if Path("/run/secrets").exists() else Path(os.getenv("SECRETS_PATH", "./secrets"))
class Settings(BaseSettings):
project_name: str = Field(default="Holiday-keeper-fastapi")
version: str = Field(default="1.0.0")
environment: str = environment
server_host: str = Field(default="0.0.0.0")
server_port: int = Field(default=8090)
cors_origins: str = Field(default="http://localhost:3000,http://localhost:5173")
log_level: str = Field(default="INFO")
redis_url: str = Field(default="redis://localhost:6379")
nager_api_base_url: str = Field(default="https://date.nager.at/api/v3")
database_url: str
sync_database_url: str
model_config = {
"env_file": env_file,
"env_file_encoding": "utf-8",
"case_sensitive": False,
"extra": "ignore",
"secrets_dir": _secrets_dir,
# "validate_default": True,
# "use_enum_values": True
}
@property
def cors_origins_list(self) -> List[str]:
return [origin.strip() for origin in self.cors_origins.split(",")]
@property
def is_dev(self) -> bool:
return self.environment.lower() == "dev"
@lru_cache()
def get_settings() -> Settings:
return Settings()
settings = get_settings()
@lru_cache와 싱글톤 패턴
@lru_cache()
def get_settings():
return Settings()
settings = get_settings()
Spring의 경우 기본적으로 싱글톤이지만 FastAPI는 그렇지 않다. 객체를 Spring의 Bean 처럼 활용하기 위함이다
get_settings( )
- 단순한 팩토리 함수 (객체 생성 반환)
@lru_cache
- 함수 결과 캐싱 ⇒ 싱글톤 구현
- 첫번째 호출 시에만 Settings() 객체 생성
- 함수를 여러 번 호출해도 캐싱한 결과 반환
- 여기까지가 Spring Boot의 싱글톤 빈을 수동으로 구현한 것이다.
settings = get_settings( )
- 전역 객체 생성
- 모듈 로드 시 한 번만 실행
- 다른 파일에서
from app.core.config import settings시 같은 객체를 사용한다. - 이렇게 편의 객체를 만들지 않으면 매번 다른 파일에서 함수를 호출해야 한다
- 전역 객체 생성
Pydantic의 Field
from pydantic import Field
class Settings(BaseSettings):
database_url: str
server_port: int = Field(default=8090, ge=1024, le=65535, description="서버 포트")
Field 함수를 사용하지 않는 경우
- 예시 :
database_url: str - 단순히 타입 힌트를 제공하며, 필드가 필수값임을 선언하는 것
BaseSettings는 이 필드 이름과 일치하는 것을 찾아 로딩한다.- 환경 변수 (OS 환경변수)
- .env 파일 값
- model_config에 설정된 secrets_dir 내의 동일한 파일 이름
- 예시 :
database_url.txt
- 예시 :
- 예시 :
Field 함수를 사용하는 경우
단순히 타입을 선언하는 것이 아닌 기본값을 지정하거나 유효성 검사 등 설정 가능
Spring Boot의 @Value + Bean Validation을 합친 기능
- 필드의 이름과 환경 변수의 이름이 같아야 동작한다
default- 기본값 설정
- default 우선 순위
- 환경변수 (OS 환경변수)
- .env 파일 값
- Field(default=값)
ge/le- 범위 제한 (greater equal, less equal)
description- API 문서에 표시될 설명
alias# .env MY_CUSTOM_PORT=8080 # Python server_port: int = Field(default=8090, alias="MY_CUSTOM_PORT")- 다른 이름으로 매핑해야 하는 경우 사용
BaseSettings
from pydantic_settings import BaseSettings
environment = os.getenv("ENVIRONMENT", "dev")
env_file = f".env.{environment}"
_secrets_dir = Path("/run/secrets") if Path("/run/secrets").exists() else Path(os.getenv("SECRETS_PATH", "./secrets"))
class Settings(BaseSettings):
name: str
version: str = Field(default="1.0.0")
model_config = {
"env_file": env_file, # 읽을 파일 경로
"env_file_encoding": "utf-8", # 파일 인코딩
"case_sensitive": False, # 대소문자 구분 안 함
"extra": "ignore", # 추가 필드 처리 방식
"validate_default": True, # 기본값도 검증할지
"use_enum_values": True, # Enum 값 사용
"secrets_dir": _secrets_dir, # Secret 폴더 경로 지정
}
BaseSettings
- Spring Boot의 @ConfigurationProperties
.env파일 자동 읽기- 환경변수 자동 매핑
- 타입 변환 (문자열 → 정수, 불린 등)
- 유효성 검증
프로파일 동적 설정
environment = os.getenv("ENVIRONMENT", "dev") env_file = f".env.{environment}" class Settings(BaseSettings): model_config = { "env_file": env_file, }- 환경 설정 파일
.env.dev.env.prod.env.test
- 환경 변수 설정에 따라 동적으로 사용하는 설정이다.
- 시크릿 설정은 시크릿 파일로 관리한다.
- 환경 설정 파일
_secrets_dir 전역 변수
- Docker Secrets 환경의 경로인
/run/secrets를 먼저 확인하고 없다면, 로컬 환경을 확인한다. - 전역 변수로 설정하는 이유는 Settings 클래스 정의 시 그 값을 활용하기 위함이다
- Docker Secrets 환경의 경로인
model_config = { }
BaseSettings가 어떻게 환경변수를 읽을지 설정하는 메타데이터 설정
Settings 클래스가 정의될 때, model_config 내부의 값들도 함께 정해진다
- Settings 클래스의 인스턴스가 생기기 전에는 클래스 내부 필드의 값을 참조할 수 없다
case_sensitive- Spring Boot는 kebab-case를 camelCase로 자동 변환
- 환경변수 이름의 대소문자를 구분하는지 여부
extra: "ignore"- Spring Boot의
ignoreUnknownFields = true - Settings 클래스에 없는 필드에 대한 처리 여부
extra: "ignore"- Settings 클래스에 없는 환경변수 무시
extra: "allow"일 때만settings.UNKNOWN_FIELD로 접근 가능
- Spring Boot의
validate_default- 기본값이 잘못된 경우를 검증하는지 여부
use_enum_values- Settings 클래스에서 다른 enum 클래스를 필드 타입으로 사용할지 여부
class Settings(BaseSettings): environment: Environment = Environment.DEV class Environment(Enum): DEV = "development" PROD = "production" use_enum_values=True⇒ "development" 값 사용use_enum_values=False⇒ Environment.DEV 객체 사용
- Settings 클래스에서 다른 enum 클래스를 필드 타입으로 사용할지 여부
secrets_dir- Secret 폴더의 경로를 지정하여 해당 경로의 값들도 읽을 수 있게 설정한다
Pydantic v1 (구버전)
class Settings(BaseSettings): name: str class Config: # 내부 클래스 env_file = ".env" case_sensitive = True
@property
class Settings:
cors_origins: str = "http://localhost:3000,http://localhost:5173"
@property
def cors_origins_list(self) -> List[str]:
return self.cors_origins.split(",")
# 사용법
settings.cors_origins_list
- 함수를 속성처럼 접근 가능하게 한다
- Java의 Getter와 같지만 호출 시 괄호가 필요 없다
- @{property이름}.setter
class Settings: def __init__(self): self._cors_data = "http://localhost:3000" @property def cors_origins(self): # property 이름 return self._cors_data @cors_origins.setter def cors_origins(self, value): self._cors_data = value- 이렇게 setter를 지정할 수 있다.
- 단, 메서드 이름도 property와 동일해야 한다.
- 환경변수 설정이므로 Settings 클래스에서는 쓸 일이 없다.
MySQL 연결 설정
# secrets/database_url
mysql+aiomysql://root:1234@127.0.0.1:3306/hk_db?charset=utf8mb4
# secrets/sync_database_url
mysql+pymysql://root:1234@127.0.0.1:3306/hk_db?charset=utf8mb4
# .env.dev
DATABASE_URL="mysql+aiomysql://root:1234@127.0.0.1:3306/hk_db?charset=utf8mb4"
SYNC_DATABASE_URL="mysql+pymysql://root:1234@127.0.0.1:3306/hk_db?charset=utf8mb4"
# config.py
class Settings:
database_url: str
sync_database_url: str
개발 / 운영 DB 설정
- 개발 환경
- 변수를 읽어들이는 순서에 따라 먼저
.env.dev에사 읽어들인다. .env.dev에는 값이 설정되어 있으므로 해당 값을 읽는다.
- 변수를 읽어들이는 순서에 따라 먼저
- 운영 환경
- 변수를 읽어들이는 순서에 따라 먼저
.env.prod에사 읽어들인다. .env.prod에는 값이 설정되어 있지 않다.- 따라서, 다음 순서인 시크릿 파일을 읽어서 값을 읽는다.
- 변수를 읽어들이는 순서에 따라 먼저
- 개발 환경
FastAPI DB 연결 설정
- diriver, username, password를 url 경로에 포함시켜 작성
charset=utf8mb4- 한글 처리 및 이모지까지 포함하는 설정
autocommit=true- connection 레벨에서, 즉 MySQL 드랑버 레벨에서 autocommit 처리
- FastAPI는 SQLAlchemy가 트랜잭션 관리하므로 불필요
use_unicode=1- MySQL 드라이버가 유니코드 문자열을 제대로 처리하도록 하는 설정
- 최신 Python MySQL 드라이버에서는 기본값
rewriteBatchedStatements=true- SQLAlchemy가 자체적으로 배치 처리
serverTimezone=Asia/Seoul- 대부분 자동 처리하지만 명시적으로 설정 가능
기존 Spring Boot 설정
spring: datasource: url: jdbc:mysql://127.0.0.1:3306/hk_db?characterEncoding=utf8&autoReconnect=true&serverTimezone=Asia/Seoul&rewriteBatchedStatements=true username: root password: 1234 driver-class-name: com.mysql.cj.jdbc.Driver
app/core/middleware.py
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
import time
import logging
from app.core.config import settings
logger = logging.getLogger(__name__)
def add_middleware(app: FastAPI):
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins_list,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"]
)
if not settings.is_dev:
app.add_middleware(
TrustedHostMiddleware,
allowed_hosts=["yourdomain.com", "*.yourdomain.com"]
)
if settings.is_dev:
@app.middleware("http")
async def log_requests(request: Request, call_next):
start_time = time.time()
response = await call_next(request)
process_time = time.time() - start_time
logger.info(
f"{request.method} {request.url.path} - "
f"Status: {response.status_code} - "
f"Time: {process_time:.4f}s"
)
return response
중요 설정
app.add_middleware( )
from fastapi import FastAPI, Request def add_middleware(app: FastAPI): app.add_middleware( TrustedHostMiddleware, allowed_hosts=["yourdomain.com", "*.yourdomain.com"] )add_middleware()- FastAPI에 내장된 함수
미들웨어 개념
- Spring Boot의 Filter/Interceptor
@Component public class LoggingFilter implements Filter { @Override public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) { // 요청 전 처리 chain.doFilter(request, response); // 다음 필터/컨트롤러로 // 응답 후 처리 } } - FastAPI의 미들웨어
@app.middleware("http") async def logging_middleware(request, call_next): # 요청 전 처리 response = await call_next(request) # 다음 미들웨어/라우터로 # 응답 후 처리 return response- 미들웨어는 요청/응답을 가로채서 처리하는 중간 계층
- 인증, 로깅, CORS, 압축 등
- Spring Boot의 Filter/Interceptor
TrustedHostMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware def add_middleware(app: FastAPI): app.add_middleware( TrustedHostMiddleware, allowed_hosts=["yourdomain.com", "*.yourdomain.com"] )- 악의적인 Host 헤더로 인한 캐시 중독, 리다이렉트 공격을 방지
Spring Security에서 Host Header 공격 방지와 비슷한 개념
@Bean public SecurityFilterChain filterChain(HttpSecurity http) { http.headers().frameOptions().sameOrigin(); }
logger
import logging logger = logging.getLogger(name)- **
logging.getLogger(__**name**__)**- 로거 인스턴스 생성
- Spring의 LoggerFactory 같은 것
__name__- 현재 모듈명 (
"app.core.middleware") - 모듈별로 구분된 로거를 만들어 로그 출처 식별 가능
- 현재 모듈명 (
- **
@app.middleware("http")
from fastapi import FastAPI, Request import time import logging def add_middleware(app: FastAPI): @app.middleware("http") async def log_requests(request: Request, call_next): start_time = time.time() response = await call_next(request) process_time = time.time() - start_time logger.info( f"{request.method} {request.url.path} - " f"Status: {response.status_code} - " f"Time: {process_time:.4f}s" ) return response@app.middleware("http")- 커스텀 HTTP 미들웨어를 등록하는 데코레이터
- Spring Boot의 HandlerInterceptor를 함수로 구현한 것
Request- Http Request 객체
- Spring Boot의 HttpServletRequest와 동일한 역할
call_nextasync def middleware(request: Request, call_next): # 전처리 response = await call_next(request) # 다음 미들웨어/라우터로 전달 # 후처리 return responsecall_next는 다음 미들웨어나 실제 라우터 함수를 호출하는 함수실행 순서
- 미들웨어 A 전처리
call_next()→ 미들웨어 B 전처리call_next()→ 실제 라우터 함수 실행- 미들웨어 B 후처리
- 미들웨어 A 후처리
Spring Boot의 FilterChain.doFilter()와 동일한 개념
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) { // 전처리 chain.doFilter(request, response); // 다음 필터/컨트롤러로 전달 // 후처리 }
app/main.py
from fastapi import FastAPI
from app.core.config import settings
from app.core.middleware import add_middleware
app = FastAPI(
title="Holiday Keeper API",
description="공휴일 정보 관리 API",
version="1.0.0",
docs_url="/docs",
redoc_url="/redoc"
)
add_middleware(app)
@app.get("/")
async def root():
return {"message": "Holiday Keeper API is running!"}
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"service": "holiday-keeper-api",
"version": "1.0.0"
}
@app.get("/test")
async def test():
return {
"message": "테스트 성공",
"data": ["공휴일1", "공휴일2", "공휴일3"]
}
@app.get("/config")
async def get_config():
if not settings.is_dev:
return {"error": "개발 환경에서만 사용 가능"}
return {
"project_name": settings.project_name,
"version": settings.version,
"environment": settings.environment,
"server_host": settings.server_host,
"server_port": settings.server_port,
"cors_origins": settings.cors_origins_list,
"log_level": settings.log_level
}
middleware.py의
add_middleware(app)를 활용하여 관심사 분리middleware.py 도입 전
from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware app = FastAPI( title="Holiday Keeper API", description="공휴일 정보 관리 API", version="1.0.0", docs_url="/docs", redoc_url="/redoc" ) app.add_middleware( CORSMiddleware, allow_origins=settings.cors_origins_list, allow_credentials=True, allow_methods=["*"], allow_headers=["*"] )
2.3. DB 연결 설정 및 간단한 API 구현하기
app/core/database.py
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from app.core.config import settings
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
sync_engine = create_engine(
settings.sync_database_url,
pool_pre_ping=True,
echo=settings.is_dev
)
async_engine = create_async_engine(
settings.database_url,
pool_pre_ping=True,
echo=settings.is_dev
)
SessionLocal = sessionmaker(
bind=sync_engine,
autocommit=False,
autoflush=False,
)
AsyncSessionLocal = async_sessionmaker(
bind=async_engine,
autocommit=False,
autoflush=False,
expire_on_commit=False
)
async def get_db():
async with AsyncSessionLocal() as db:
yield db
def get_sync_db():
sync_db = SessionLocal()
try:
yield sync_db
finally:
sync_db.close()
DB Engine
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import create_async_engine
from app.core.config import settings
sync_engine = create_engine(
settings.sync_database_url,
pool_pre_ping=True,
echo=settings.is_dev
)
async_engine = create_async_engine(
settings.database_url,
pool_pre_ping=True,
echo=settings.is_dev
)
Engine
- 데이터베이스 Connection Pool을 관리하는 객체
- 데이터베이스 연결 생성 및 연결 재사용
- 엔진은 전역에서 하나만 생성해서 공유하는 것이 권장된다
- 엔진마다 별도의 연결 풀을 생성하면 낭비이다.
- 또 엔진 생성에는 비용이 들기 때문에 좋지 않다.
create_engine 함수 (동기 엔진)
- 동기 방식으로 엔진 생성
- 블로킹 방식
- 각 쿼리가 완료될 때까지 대기
create_async_engine 함수 (비동기 엔진)
- 비동기 방식으로 엔진 생성
- 논블로킹 방식
- 쿼리 실행 중에 다른 작업 수행 가능
engine 파라미터
url: str | URL- 필수 인자
- 데이터베에스 연결 URL
pool_pre_ping=True- connection pool에서 connection을 가져올 때 먼저 ping을 보내 connection 상태 확인
echo=settings.is_dev- 개발 모드일 때 실행되는 SQL 쿼리를 콘솔에 출력
settings.is_dev는 임의로 만든 환경 확인 함수
pool_size- 연결 풀의 기본 크기
- default :
5
max_overflow- 풀이 가득 찰 때 추가로 생성할 수 있는 연결 수
- default :
10
pool_timeout- 연결을 얻기 위해 대기할 최대 시간(초)
pool_recycle- 연결의 최대 생존 시간(초)
- 여기서 정해진 시간이 지남에 따라 연결을 폐기 후 새로 생성한다.
DB 세션 생성
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import async_sessionmaker
SessionLocal = sessionmaker(
bind=sync_engine, # 사용할 엔진
autocommit=False, # 자동 커밋 비활성화
autoflush=False, # 자동 플러시 비활성화
)
AsyncSessionLocal = async_sessionmaker(
bind=async_engine, # 비동기 엔진 바인딩
autocommit=False, # 자동 커밋 비활성화
autoflush=False, # 자동 플러시 비활성화
expire_on_commit=False # 커밋 후 객체 만료 방지
)
SessionLocal
- 세션 팩토리 객체
SessionLocal()을 호출할 때마다 새로운 세션 인스턴스 생성- 각 요청/트랜잭션마다 독립적인 세션 사용
sessionmaker
- 세션을 생성하는 팩토리 함수
sessionmaker- (동기)async_sessionmaker- (비동기)- 호출할 때마다 새로운 세션 인스턴스 생성
session 파라미터
bind- 필수 인자
- 연결할 엔진 지정
autocommit=False- ORM 세션 레벨에서의 autocommit 처리
- 명시적으로
commit()을 호출해야 변경사항이 저장된다 - 반드시
false로 해야 트랜잭션 관리를 할 수 있다 true인 경우, 즉시 DB에 저장되어 롤백 불가능!
autoflush=Falsetrue: 기본값False로 설정해야 ****쿼리 전에 자동으로 flush하지 않는다- 즉, 수동으로 제어하는 것으로 성능상 유리하다.
expire_on_committrue: 기본값False로 설정해야 ****커밋 후에도 객체의 속성에 접근 가능expire_on_commit=False가 비동기에서 중요- 세션이 닫힌 후에도 해당 객체를 활용하기 위해서
class_- 사용할 세션 클래스 지정
- 생략 시 기본값으로 지정된다
sessionmaker⇒ Sessionasync_sessionmaker⇒ AsyncSession
- 커스템 세션 클래스의 경우 명시적으로 지정
infoSessionLocal = sessionmaker( bind=engine, info={ 'app_name': 'my_app', 'version': '1.0.0', 'debug_mode': True } ) # 저장 정보 사용 방법 session.info['key']- 추가 세션 메타데이터 딕셔너리
DB 세션 생성
async def get_db():
async with AsyncSessionLocal() as db:
yield db
def get_sync_db():
sync_db = SessionLocal()
try:
yield sync_db
finally:
sync_db.close()
async with 문법
async with AsyncSessionLocal() as session: # 세션 사용 pass # 자동으로 session.close() 호출됨async with는 비동기 컨텍스트 매니저- 즉, 세션의 생명주기 관리로 세션 생성/정리를 자동으로 처리
- 진입 시 :
AsyncSessionLocal()로 세션 생성 - 종료 시 : 자동으로
await session.close()호출 - 예외 발생 시 : 자동으로 정리 작업 수행
- 일반
with와 비슷하지만__aenter__와__aexit__메서드가 비동기
try, yield, finally 문법
async def get_db(): async with AsyncSessionLocal() as session: try: yield session # 세션을 제공 finally: await session.close() # 정리 작업yield- 이 함수를 제너레이터로 만든다
- FastAPI가 의존성 주입할 때 필요한 문법
try- 세션 사용 중 예외가 발생할 수 있는 구간
finally- 예외 발생 여부와 상관없이 반드시 실행되는 정리 코드
- FastAPI에서는 의존성 주입 시 이 패턴을 사용한다
- 요청 시작 시 세션 생성
- 요청 처리 중 세션 사용
- 요청 완료 후 세션 정리
app/common/models/base.py (Base 모델)
from sqlalchemy import Column, Integer, DateTime, func
from sqlalchemy.orm import DeclarativeBase
class Base(DeclarativeBase):
pass
class TimestampMixin:
created_at = Column(DateTime, server_default=func.now(), nullable=False)
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), nullable=False)
class BaseModel(Base, TimestampMixin):
__abstract__ = True
id = Column(Integer, primary_key=True)
DeclarativeBase
- SQLAlchemy 2.0에서 도입된 새로운 베이스 클래스
- 메타클래스 자동 설정
DeclarativeMeta메타클래스를 자동으로 적용- 복잡한 설정을 자동으로 적용
- 타입 힌팅 개선
- Python의 타입 시스템과 더 잘 통합 (IDE에서 자동완성 개선)
- 타입 검사 지원
- 레지스트리 관리
- SQLAlchemy가 모든 테이블 정의를 중앙에서 관리
registry: 매핑된 클래스들의 레지스트리
- 스키마 생성
- 실제 데이터베이스에 테이블을 만들거나 삭제할 때 사용
Base.metadata.create_all(engine)
metadata객체를 통해 테이블 생성/삭제 관리
- 실제 데이터베이스에 테이블을 만들거나 삭제할 때 사용
- 메타클래스 자동 설정
이전 방식
# SQLAlchemy 1.x from sqlalchemy.ext.declarative import declarative_base Base = declarative_base() # SQLAlchemy 2.0+ from sqlalchemy.orm import DeclarativeBase class Base(DeclarativeBase): pass
- SQLAlchemy 2.0에서 도입된 새로운 베이스 클래스
네이밍 관례
Base
- SQLAlchemy의 루트 베이스 클래스
- 모든 모델의 최상위 부모
- 보통
Base또는DeclarativeBase로 명명
BaseModel
- 공통 필드와 로직을 담는 추상 베이스 모델
Base를 상속받아 확장한다
TimestampMixin
- 믹스인 패턴 : 특정 기능을 여러 클래스에 주입
- 보통
created_at,updated_at같은 공통 시간 필드 제공 - 이름에
Mixin접미사로 목적 명시
__abstract__ = True
- __abstract__의 의미
이 클래스는 실제 테이블을 생성하지 않는다고 설정
- 다른 모델들이 상속받기 위한 공통 속성 정의 역할
- Spring Boot의 BaseEntity와 유사
예시
class TimestampMixin: created_at = Column(DateTime, server_default=func.now(), nullable=False) updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), nullable=False) class BaseModel(Base, TimestampMixin): __abstract__ = True # 테이블 생성 안 함 id = Column(Integer, primary_key=True) class Country(BaseModel): # BaseModel 상속 __tablename__ = 'countries' # 실제 테이블 생성 name = Column(String(100))BaseModel⇒ 테이블 생성 안 됨Country⇒countries테이블 생성id,created_at,updated_at,name포함
- __abstract__의 의미
func
- SQLAlchemy의 SQL 함수 네임스페이스
날짜/시간
func.now() # 현재 시간 func.current_date() # 현재 날짜 func.current_timestamp() # 현재 타임스탬프 func.date_part('year', date_col) # 날짜 부분 추출집계 함수
func.count(User.id) # COUNT func.sum(Order.amount) # SUM func.avg(Product.price) # AVERAGE func.min(User.age) # MIN func.max(User.age) # MAX문자열 함수
func.upper(User.name) # 대문자 변환 func.lower(User.email) # 소문자 변환 func.length(User.name) # 문자열 길이 func.concat(User.first_name, ' ', User.last_name) # 문자열 연결수학 함수
func.abs(column) # 절댓값 func.round(column, 2) # 반올림 func.floor(column) # 내림 func.ceiling(column) # 올림
app/domains/country/models.py (Country 모델)
from sqlalchemy import Column, Integer, String, DateTime
from sqlalchemy.orm import relationship
from app.common.models.base import BaseModel
class Country(BaseModel):
__tablename__ = 'countries'
country_code = Column(String(10), unique=True, nullable=False)
name = Column(String(100), nullable=False)
holidays = relationship("Holiday", back_populates="country")
Column( ) 함수
Column(type_, *args, **kwargs)데이터 타입
Column(Integer) # 정수형 Column(String(100)) # 최대 100자 문자열 Column(Text) # 긴 텍스트 Column(DateTime) # 날짜/시간 Column(Boolean) # 불린 Column(Float) # 실수 Column(DECIMAL(10, 2)) # 정밀한 소수점제약 조건
Column(Integer, primary_key=True) # 기본키 Column(String(50), nullable=False) # NOT NULL Column(String(100), unique=True) # UNIQUE 제약 Column(Integer, default=0) # 기본값 Column(DateTime, server_default=func.now()) # 서버 기본값- 서버 기본값이란
- DB 레벨의 기본값을 말한다.
- 일반 기본값의 경우, 쿼리를 넘길 때 기본값으로 설정하여 내보내는 것이다.
- 서버 기본값은 쿼리에서 생략해도 DB 서버가 알아서
NOW()로 입력한다.
- 서버 기본값이란
인덱스와 성능
Column(String(50), index=True) # 인덱스 생성 Column(Integer, primary_key=True) # PK는 자동으로 인덱스- PK와 UK는 자동으로 인덱스 설정된다
- FK는 DB마다 다르다
업데이트 관련
Column(DateTime, onupdate=func.now()) # 업데이트 시 자동 갱신 Column(DateTime, server_onupdate=func.now()) # 서버 레벨 업데이트외래키
Column(Integer, ForeignKey('users.id')) # 외래키 설정- ForeignKey 설정에는 실제 테이블명 사용
기타 옵션
Column(String(100), comment='사용자 이름', # 주석 info={'label': 'Username'}, # 추가 메타데이터 quote=True, # 컬럼명 따옴표 처리 autoincrement=True) # 자동 증가 (PK의 기본값)autoincrement=True- DB별 자동 증가 전략 선택
- 다만, PK + Integer의 경우 자동 증가가 자동 적용된다.
quote=True예약어나 특수문자가 포함된 컬럼명인 경우 설정
⇒ 따옴표 처리로 가능하게 한다
relationship( )
기본 관계 설정
# 현재 코드 holidays = relationship("Holiday", back_populates="country") relationship("User") # 기본 관계 relationship("User", backref="orders") # 역참조 자동 생성 relationship("User", back_populates="orders") # 양방향 명시적 설정relationship("User", XXX)- 대상 클래스는 모델 클래스명을 사용한다.
외래키 관련
relationship("User", foreign_keys=[user_id]) # 외래키 명시 relationship("User", primaryjoin="Order.user_id==User.id") # 조인 조건 명시로딩 전략
relationship("User", lazy="select") # 기본: 필요시 로딩 relationship("User", lazy="joined") # 조인으로 즉시 로딩 relationship("User", lazy="subquery") # 서브쿼리로 로딩 relationship("User", lazy="dynamic") # 쿼리 객체 반환 relationship("User", lazy="selectin") # IN 절 사용해서 로딩- Default :
lazy="select"- 현업에서 가장 많이 사용
- 기본값 사용 시 생략 가능
- 테이블 정의는 기본값으로 두고, 쿼리할 때 상황에 맞게 제어
- 성능 최적화가 필요할 때 ⇒ N+1 문제 해결
lazy="joined"lazy="selectin"
- Default :
컬렉션 타입
relationship("Tag", collection_class=set) # Set으로 관리 relationship("Tag", collection_class=list) # List로 관리 (기본값)- Default :
collection_class=list
- Default :
삭제 동작
relationship("Order", cascade="all, delete-orphan", # 연쇄 삭제 passive_deletes=True) # DB 레벨 삭제cascade="all”은 여러 cascade 옵션의 조합이다.save-update: 부모 저장 시 자식도 저장merge: 부모 병합 시 자식도 병합refresh: 부모 새로고침 시 자식도 새로고침expunge: 부모 세션에서 제거 시 자식도 제거delete: 부모 삭제 시 자식도 삭제
- delete-orphan
- 부모와 연결이 끊어진 자식 자동 삭제
user = session.query(User).first() order = user.orders[0] user.orders.remove(order) # 관계에서 제거 session.commit()- Order 객체가 자동으로 DB에서 삭제됨
- 부모와 연결이 끊어진 자식 자동 삭제
- passive_deletes=True
- Default :
passive_deletes=False False는 삭제 시 SQLAlchemy가 SELECT로 모든 자식들을 찾아 하나씩 DELETE- 느리고 메모리 많이 사용
True는 삭제 시 DB의 CASCADE가 알아서 관련 자식 삭제- 빠르다
- Default :
정렬
relationship("Comment", order_by="Comment.created_at.desc()") # 정렬 조건- JPA의 @OrderBy와 동일
- 데이터를 가져올 때 정렬 조건
조건부 관계
class Order(BaseModel): # 모든 상품 product = relationship("Product") # 활성화된 상품만 active_product = relationship("Product", primaryjoin="and_(Order.product_id==Product.id, " "Product.is_active==True)") # 사용 order = session.query(Order).first() print(order.product) # 모든 상품 print(order.active_product) # 활성화된 상품만 (조건에 맞는 것만)양방향 관계
# Country 모델 class Country(BaseModel): holidays = relationship("Holiday", back_populates="country") # Holiday 모델 class Holiday(BaseModel): country_id = Column(Integer, ForeignKey('countries.id')) country = relationship("Country", back_populates="holidays")- 직접 FK 설정을 해줘야 실제 연결이 된다
- 즉, 단방향이든 양방향이든 FK 설정이 완료되어야지 실제 연결된다
- 직접 FK 설정을 해줘야 실제 연결이 된다
app/domains/holiday/models.py (Holiday 모델)
from sqlalchemy import Column, Integer, String, Date, Boolean, UniqueConstraint, Index, ForeignKey
from sqlalchemy.orm import relationship
from app.common.models.base import BaseModel
class Holiday(BaseModel):
__tablename__ = 'holidays'
__table_args__ = (
UniqueConstraint('country_id', 'date', 'name', name='uk_1'),
Index('idx_year', 'holiday_year'),
Index('idx_country_year', 'country_id', 'holiday_year'),
)
date = Column(Date, nullable=False)
name = Column(String(200))
local_name = Column(String(200))
holiday_year = Column(Integer, nullable=False)
launch_year = Column(Integer)
is_global = Column(Boolean, default=False)
types = Column(String(100))
counties = Column(String(500))
country_id = Column(Integer, ForeignKey('countries.id'), nullable=False, index=True)
country = relationship("Country", back_populates="holidays")
1대다 관계 설정
Spring Boot의 경우
- 어노테이션 기반으로 관계 정의
mappedBy로 연관관계 주인이 아님을 명시- 자동으로 FK는 상대방이 가진다
FastAPI의 경우
- FK 필드를 통해 관계를 결정
- FK를 가진 테이블이 연관관계의 주인이 된다
back_populates vs backref
둘 모두 양방향 설정으로 양쪽에서 모두 연관 대상에 접근할 수 있다
backref
- 한쪽 모델에서만 관계를 정의한다
- 다른 쪽 모델은 relationship 정의 안 하지만, 자동 생성된다.
- 간단한 프로젝트에서 주로 사용
- 한쪽 모델에서만 관계를 정의한다
back_populates
- 양쪽 모델 모두에서 관계를 정의한다
- 명시적으로 설정하여 각 관계가 어떻게 설정되어 있는지 명확하다.
- 또한 각 관계마다 다른 옵션 설정 가능
- IDE 지원 : 자동완성과 타입 검사 더 잘됨
다대다 관계 설정
단순 다대다 관계 (추가 필드 없는 경우)
Spring Boot의 경우
@Entity public class User { @ManyToMany @JoinTable( name = "user_roles", joinColumns = @JoinColumn(name = "user_id"), inverseJoinColumns = @JoinColumn(name = "role_id") ) private List<Role> roles; } @Entity public class Role { @ManyToMany(mappedBy = "roles") private List<User> users; }FastAPI의 경우
from sqlalchemy import Table user_role_table = Table( 'user_roles', Base.metadata, Column('user_id', Integer, ForeignKey('users.id')), Column('role_id', Integer, ForeignKey('roles.id')) ) class User(BaseModel): __tablename__ = 'users' name = Column(String(100)) roles = relationship("Role", secondary=user_role_table, back_populates="users") class Role(BaseModel): __tablename__ = 'roles' name = Column(String(50)) users = relationship("User", secondary=user_role_table, back_populates="roles")- 중간 테이블 정의
복잡한 다대다 관계 (중간 테이블에 추가 필드)
Spring Boot의 경우
@Entity public class UserRole extends BaseEntity { @ManyToOne(fetch = FetchType.LAZY) @JoinColumn(name = "user_id") private User user; @ManyToOne(fetch = FetchType.LAZY) @JoinColumn(name = "role_id") private Role role; private LocalDateTime assignedDate; } @Entity public class User extends BaseEntity { @OneToMany(mappedBy = "user", cascade = CascadeType.ALL, orphanRemoval = true) private List<UserRole> userRoles = new ArrayList<>(); } @Entity public class Role extends BaseEntity { @OneToMany(mappedBy = "role", cascade = CascadeType.ALL, orphanRemoval = true) private List<UserRole> roleUsers = new ArrayList<>(); }FastAPI의 경우
from sqlalchemy import DateTime, func class UserRole(BaseModel): __tablename__ = 'user_roles' user_id = Column(Integer, ForeignKey('users.id'), primary_key=True) role_id = Column(Integer, ForeignKey('roles.id'), primary_key=True) assigned_date = Column(DateTime, server_default=func.now()) user = relationship("User", back_populates="user_roles") role = relationship("Role", back_populates="role_users") class User(BaseModel): __tablename__ = 'users' name = Column(String(100)) user_roles = relationship("UserRole", back_populates="user") class Role(BaseModel): __tablename__ = 'roles' name = Column(String(50)) role_users = relationship("UserRole", back_populates="role")
app/common/crud/base.py
from typing import Type, TypeVar, Any, Optional, List
from sqlalchemy import select, func, text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
class CRUDBase:
def __init__(self, model: Type[ModelType]):
self.model = model
async def get(self, db: AsyncSession, obj_id: Any) -> Optional[ModelType]:
return await db.get(self.model, obj_id)
async def get_by_field(self, db: AsyncSession, field_name: str, value: Any) -> Optional[ModelType]:
stmt = select(self.model).where(getattr(self.model, field_name) == value)
result = await db.execute(stmt)
return result.scalars().one_or_none()
async def get_by_fields(self, db: AsyncSession, **kwargs: Any) -> Optional[ModelType]:
conditions = [getattr(self.model, key) == value for key, value in kwargs.items()]
stmt = select(self.model).where(*conditions)
result = await db.execute(stmt)
return result.scalars().one_or_none()
async def get_multi(self, db: AsyncSession, *, skip: int = 0, limit: int = 100) -> List[ModelType]:
stmt = select(self.model).offset(skip).limit(limit)
result = await db.execute(stmt)
return result.scalars().all()
async def get_multi_by_fields(self, db: AsyncSession, *, skip: int = 0, limit: int = 100, **kwargs: Any) -> List[ModelType]:
conditions = [getattr(self.model, key) == value for key, value in kwargs.items()]
stmt = select(self.model).where(*conditions).offset(skip).limit(limit)
result = await db.execute(stmt)
return result.scalars().all()
async def get_all(self, db: AsyncSession) -> List[ModelType]:
stmt = select(self.model)
results = await db.execute(stmt)
return results.scalars().all()
async def count(self, db: AsyncSession) -> int:
stmt = select(func.count()).select_from(self.model)
result = await db.execute(stmt)
return result.scalar_one()
async def create(self, db: AsyncSession, *, obj_in: Any) -> ModelType:
db_obj = self.model(**obj_in.model_dump())
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
async def update(self, db: AsyncSession, db_obj: ModelType, obj_in: Any) -> ModelType:
update_data = obj_in.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(db_obj, key, value)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
async def remove(self, db: AsyncSession, *, db_obj: ModelType) -> ModelType:
await db.delete(db_obj)
await db.commit()
return db_obj
초기 설정
TypeVar
from typing import Type, TypeVar, Any, Optional, List from sqlalchemy.orm import DeclarativeBase ModelType = TypeVar("ModelType", bound=DeclarativeBase) class CRUDBase: def __init__(self, model: Type[ModelType]): self.model = modelTypeVar란
- Python의 타입 힌팅 시스템에서 제네릭 타입 변수를 정의하는 문법
"ModelType"- TypeVar의 이름을 문자열로 지정
- 관례적으로 변수명과 동일하게 설정
bound=DeclarativeBase- TypeVar가 가질 수 있는 타입의 upper bound을 제한
- 즉, 이 경우 DeclarativeBase를 상속받은 클래스만 올 수 있다
- SQLAlchemy ORM 모델들은 모두
DeclarativeBase를 상속받는다
- SQLAlchemy ORM 모델들은 모두
그 외 파라미터
# 제약 없는 TypeVar T = TypeVar('T') # 특정 타입들로만 제한 NumberType = TypeVar('NumberType', int, float, complex) # 공변성/반공변성 설정 T_co = TypeVar('T_co', covariant=True) # 공변 T_contra = TypeVar('T_contra', contravariant=True) # 반공변 # bound와 constraints 조합 UserType = TypeVar('UserType', bound=BaseModel, int, str)현업 명명 사례
# 모델/엔티티 관련 ModelType = TypeVar('ModelType', bound=BaseModel) EntityT = TypeVar('EntityT', bound=Entity) DomainT = TypeVar('DomainT') # 서비스/리포지토리 패턴 ServiceT = TypeVar('ServiceT', bound=BaseService) RepoT = TypeVar('RepoT', bound=BaseRepository) # HTTP 관련 RequestT = TypeVar('RequestT', bound=BaseRequest) ResponseT = TypeVar('ResponseT', bound=BaseResponse) # 일반적인 명명 T = TypeVar('T') # 가장 일반적 K = TypeVar('K') # Key 타입 V = TypeVar('V') # Value 타입
__init__
- __init__ 함수는 Python의 생성자 메서드
__init__ 함수
- Python의 인스턴스 초기화 메서드 (생성자)
- 객체가 생성될 때 자동으로 호출
- 인스턴스의 초기 상태를 설정하는 역할
self- 인스턴스 자신을 가리키는 참조
- 인스턴스 메서드의 첫 번째 파라미터로 자동 전달
- 정적 메서드가 아닌 인스턴스 메서드임을 의미
model: Type[ModelType]Type[ModelType]- 제네릭 타입 변수 ModelType의 타입을 받는다는 타입 힌트
- 인스턴스가 아닌 클래스 자체를 받는다
단건 / 다건 조회
SQLAlchemy ORM
- 클래스는 테이블의 메타데이터를 담고 있는 설계도 역할을 한다.
- 객체를 통해 메서드를 호출하지만 내부에서는 클래스 자체만을 사용
단건 조회
async def get(self, db: AsyncSession, obj_id: Any) -> Optional[ModelType]: return await db.get(self.model, obj_id) async def get_by_field(self, db: AsyncSession, field_name: str, value: Any) -> Optional[ModelType]: stmt = select(self.model).where(getattr(self.model, field_name) == value) result = await db.execute(stmt) return result.scalars().one_or_none() async def get_by_fields(self, db: AsyncSession, **kwargs: Any) -> Optional[ModelType]: conditions = [getattr(self.model, key) == value for key, value in kwargs.items()] stmt = select(self.model).where(*conditions) result = await db.execute(stmt) return result.scalars().one_or_none()SQLAlchemy의 get( ) 함수
- Primary Key로 단일 객체 조회하는 SQLAlchemy 메서드
db.get(self.model, obj_id)- 첫 번째 인자 : 조회할 모델 클래스
- 두 번째 인자 : Primary Key 값
- 반환 : 해당 객체 또는 None
SQLAlchemy의 select( ) 함수
stmt = select(User) # SELECT * FROM user stmt = select(User.name, User.email) # SELECT name, email FROM user- 쿼리 빌더 함수로, SQL SELECT 문을 생성하는 SQLAlchemy 메서드
SQLAlchemy의 where( ) 함수
# 단일 조건 stmt = select(User).where(User.id == 1) # 여러 조건 (AND) stmt = select(User).where(User.active == True, User.age > 18) # 또는 (OR) stmt = select(User).where(User.active == True).where(User.age > 18)- SELECT 문에 WHERE 조건을 추가하는 SQLAlchemy 메서드
getattr( ) 함수
class User: name = "John" age = 25 # 일반적인 속성 접근 user.name # "John" # getattr을 사용한 동적 속성 접근 getattr(user, "name") # "John" getattr(user, "salary", 0) # 기본값 0 (속성이 없을 때)- 객체의 속성을 동적으로 가져오는 Python 내장 함수
- 해당 속성이 없을 때 기본값을 지정해서 가져올 수 있다.
**kwargs 문법
def get_by_fields(self, db: AsyncSession, **kwargs: Any): # kwargs = {"name": "John", "age": 25, "active": True} print(kwargs) # {'name': 'John', 'age': 25, 'active': True} # 딕셔너리 순회 for key, value in kwargs.items(): print(f"{key}: {value}") # 사용 예시 repo.get_by_fields(db, name="John", age=25, active=True)- 키워드 인자들을 딕셔너리로 받는 Python 문법
- 딕셔너리 순회를 통해 많이 사용한다.
- dict 클래스의 메서드인
items()등 활용
- dict 클래스의 메서드인
- 2개 이상의 필드로 조건부 조회하는 경우 정확한 숫자를 몰라도 설정 가능
SQLAlchemy의 execute( ) 함수
# 비동기 실행 result = await db.execute(stmt) # 동기 실행 result = db.execute(stmt)데이터베이스에 쿼리를 실행하는 SQLAlchemy 메서드
execute()결과의 기본 구조# 직접 fetchall() 사용 rows = result.fetchall() print(type(rows[0])) # <class 'sqlalchemy.engine.Row'> # Row 객체 접근 for row in rows: # 방법 1: 인덱스로 접근 print(row[0]) # 첫 번째 컬럼 # 방법 2: 속성명으로 접근 print(row.id) # User.id print(row.name) # User.name # 방법 3: 딕셔너리처럼 접근 print(row['id']) print(row['name']) # Row를 딕셔너리로 변환 user_dict = row._asdict() print(user_dict) # {'id': 1, 'name': 'John', 'email': 'john@example.com'}- Row 객체는 ORM 모델 객체의 메서드를 사용할 수 없다.
SQLAlchemy의 scalars( ) 함수
# ORM 객체 result = await db.execute(select(User)) users = result.scalars().all() user_obj = users[0] print(type(user_obj)) user_obj.some_method() # User 모델의 메서드 사용 가능Result 객체에서 스칼라 값들만 추출하는 SQLAlchemy 메서드
scalars( ) 대신 사용할 수 있는 메서드들
mappings()- 딕셔너리 형태로 결과 반환result = await db.execute(select(User)) user_dicts = result.mappings().all() # [{'id': 1, 'name': 'John', 'email': 'john@example.com'}, ...] # API 응답에 바로 사용 가능 return JSONResponse(content=user_dicts)튜플 결과
result = await db.execute(select(User.id, User.name)) tuples = result.all() # [(1, 'John'), (2, 'Jane'), ...] # 또는 scalars().all()로 첫 번째 컬럼만 result = await db.execute(select(User.name)) names = result.scalars().all() # ['John', 'Jane', ...]
뒤에 올 수 있는 메서드
# 모든 결과 가져오기 .all() # 리스트로 모든 결과 반환 .fetchall() # all()과 동일 # 단일 결과 가져오기 .first() # 첫 번째 결과 또는 None .one() # 정확히 하나의 결과 (없거나 여러 개면 예외) .one_or_none() # 하나의 결과 또는 None (여러 개면 예외) # 고유값들만 .unique() # 중복 제거 # 파티셔닝 .partitions(size) # 지정된 크기로 결과를 나눔- all( ) 메서드의 특징
- 데이터가 없는 경우 빈 리스트
[]반환 (None이 아님) - 항상 리스트 타입 보장
- 데이터가 없는 경우 빈 리스트
- all( ) 메서드의 특징
SQLAlchemy의 scalar( ) 함수
- 쿼리 결과의 첫 번째 행, 첫 번째 컬럼의 값을 추출
result.scalar() # 첫 번째 행의 첫 번째 컬럼 (None 가능) result.scalar_one() # 정확히 1개 값 (0개나 2개 이상이면 예외) result.scalar_one_or_none() # 1개 값 또는 None (2개 이상이면 예외) 사용 대상
- COUNT, SUM, MAX, MIN 등 집계 함수 결과
- 단일 값을 반환하는 쿼리
- 존재 여부 확인 (
EXISTS)
scalar( )와 scalars( )의 차이점
scalar()- 단일 값만 반환scalars().first()- ORM 객체 하나scalars().all()- ORM 객체 리스트- 즉, 일반 상세 조회 등에서는 scalar( )를 쓰지 않는다. 객체가 필요하기 때문에
사용 예시
# 1. COUNT - 가장 많이 사용 async def get_user_count(db: AsyncSession) -> int: result = await db.execute(select(func.count(User.id))) return result.scalar() # 정수 반환 (예: 150) # 2. 조건부 COUNT async def get_active_user_count(db: AsyncSession) -> int: result = await db.execute( select(func.count(User.id)).where(User.active == True) ) return result.scalar() # 3. MAX, MIN 값 조회 async def get_latest_user_id(db: AsyncSession) -> Optional[int]: result = await db.execute(select(func.max(User.id))) return result.scalar() # 최대 ID 값 # 4. SUM 계산 async def get_total_order_amount(db: AsyncSession, user_id: int) -> float: result = await db.execute( select(func.sum(Order.amount)).where(Order.user_id == user_id) ) return result.scalar() or 0.0 # 5. EXISTS 확인 (매우 자주 사용) async def email_exists(db: AsyncSession, email: str) -> bool: result = await db.execute( select(exists().where(User.email == email)) ) return result.scalar() # True/False # 6. 단일 컬럼 값 하나만 가져오기 async def get_user_name(db: AsyncSession, user_id: int) -> Optional[str]: result = await db.execute( select(User.name).where(User.id == user_id) ) return result.scalar() # 이름 문자열 또는 None
- 쿼리 결과의 첫 번째 행, 첫 번째 컬럼의 값을 추출
다건 조회
async def get_multi(self, db: AsyncSession, *, skip: int = 0, limit: int = 100) -> List[ModelType]: stmt = select(self.model).offset(skip).limit(limit) result = await db.execute(stmt) return result.scalars().all() async def get_multi_by_fields(self, db: AsyncSession, *, skip: int = 0, limit: int = 100, **kwargs: Any) -> List[ModelType]: conditions = [getattr(self.model, key) == value for key, value in kwargs.items()] stmt = select(self.model).where(*conditions).offset(skip).limit(limit) result = await db.execute(stmt) return result.scalars().all() async def get_all(self, db: AsyncSession) -> List[ModelType]: stmt = select(self.model) results = await db.execute(stmt) return results.scalars().all()offset( )과 limit( ) 함수
stmt = select(User).offset(10).limit(5) # 생성되는 SQL: SELECT * FROM user LIMIT 5 OFFSET 10 # 10개를 건너뛰고 5개만 가져와라- SQLAlchemy의 페이징 처리를 위한 SQL 함수들
* 키워드 전용 인자 문법
def func(a, b, *, c, d): pass # 올바른 호출 func(1, 2, c=3, d=4) # ✅ # 잘못된 호출 func(1, 2, 3, 4) # ❌ TypeError: 위치 인자로 c, d 전달 불가* 이후의 매개변수들은 반드시 키워드로 전달해야 한다
키워드 인자와 위치 인자
위치 인자의 경우, 순서에 따라 매개변수에 값이 할당된다.
greet("김철수", 25, "서울") # ✅ greet(25, "서울", "김철수") # ❌ or 의도와 다른 결과키워드 인자의 경우, 매개변수 이름을 명시하여 할당한다.
greet(name="김철수", age=25, city="서울") # ✅ greet(city="서울", name="김철수", age=25) # ✅
이렇게 키워드 인자로 사용하는 이유
users = await crud.get_multi(db, skip=20, limit=10)- 무엇을 건너뛰고 몇 개 가져올지 명확하게 명시하기 위해
집계 함수
async def count(self, db: AsyncSession) -> int:
stmt = select(func.count()).select_from(self.model)
result = await db.execute(stmt)
return result.scalar_one()
async def count_by_fields(self, db: AsyncSession, **kwargs: Any) -> int:
conditions = [getattr(self.model, key) == value for key, value in kwargs.items()]
stmt = select(func.count()).select_from(self.model).where(*conditions)
result = await db.execute(stmt)
return result.scalar_one()
사용 예시
total_users = await user_crud.count(db) # 전체 사용자 수 active_users = await user_crud.count_by_fields(db, active=True) # 활성 사용자 수scalar_one( )
scalar_one()- 오직 단 하나의 값만 나와야 한다
scalar_one()을 사용하는 이유는 COUNT 함수는 항상 숫자를 반환하기 때문- 조건에 맞는 데이터가 없는 경우
null이 아닌0을 반환
다른 집계 함수
# MAX, MIN, SUM, AVG - NULL 가능 SELECT MAX(age) FROM users; -- 데이터 없으면 NULL SELECT SUM(salary) FROM users; -- 데이터 없으면 NULL SELECT AVG(age) FROM users; -- 데이터 없으면 NULL- 그러므로, 다른 집계함수에서는
scalar()를 쓰거나scalar_one_or_none()사용 scalar()는 결과가 여러 개여도 첫번째 것만 가져온다scalar_one_or_none()는 결과가 여러 개라면 예외를 발생시킨다
- 그러므로, 다른 집계함수에서는
SQLAlchemy의 select_from( ) 함수
select(User) # SQL: SELECT * FROM user select(func.count()).select_from(User) # SQL: SELECT COUNT(*) FROM user- FROM 절을 명시적으로 지정하는 함수
일반적인
select()함수는FROM이 자동 추론된다.다만, 집계 함수를 쓰는 경우 FROM 절이 불분명한 경우가 있다
컬럼을 명시하는 경우 ⇒
select_from불필요stmt = select(func.count(User.id))- 제네릭 타입이므로 컬럼을 명시하기 힘들다
서브 쿼리로 FROM 절을 사용하는 경우도
select_from필요stmt = select(func.count()).select_from(select(User).subquery()) # SQL: SELECT COUNT(*) FROM (SELECT * FROM user) AS subquery
- FROM 절을 명시적으로 지정하는 함수
생성 / 수정 / 삭제
async def create(self, db: AsyncSession, *, obj_in: Any) -> ModelType:
db_obj = self.model(**obj_in.model_dump())
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
async def update(self, db: AsyncSession, db_obj: ModelType, obj_in: Any) -> ModelType:
update_data = obj_in.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(db_obj, key, value)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
async def remove(self, db: AsyncSession, *, db_obj: ModelType) -> ModelType:
await db.delete(db_obj)
await db.commit()
return db_obj
obj_in
- Pydantic Schemas가 들어오기 때문에
Any로 설정 - Spring Boot에서 DTO를 받아 전환 후 DB 저장하는 패턴과 유사
- Pydantic Schemas가 들어오기 때문에
model_dump( ) 함수
- Pydantic BaseModel 클래스의 메서드
Pydantic 객체를 딕셔너리로 변환한다.
# Pydantic 객체 생성 user_data = UserCreate(name="김철수", email="kim@email.com", age=30) # model_dump() 호출 - 딕셔너리로 변환 print(user_data.model_dump()) # 출력: {'name': '김철수', 'email': 'kim@email.com', 'age': 30, 'active': True}model_dump( )의 매개변수들
- 기본 사용
user_data.model_dump() # {'name': '김철수', 'email': 'kim@email.com', 'age': 30, 'active': True} - exclude_unset=True
user_data.model_dump(exclude_unset=True) # {'name': '김철수', 'email': 'kim@email.com', 'age': 30} # active는 기본값이므로 제외- 설정되지 않은 필드 제외
- 기본값 :
False - 일부 필드만 업데이트하는 경우 중요하다
- include
user_data.model_dump(include={'name', 'email'}) # {'name': '김철수', 'email': 'kim@email.com'}- 특정 필드만 포함
- exclude
user_data.model_dump(exclude={'active'}) # {'name': '김철수', 'email': 'kim@email.com', 'age': 30}- 특정 필드 제외
- by_alias=True
user_data.model_dump(by_alias=True)- 필드 별칭 사용
- exclude_none=True
user_data.model_dump(exclude_none=True)- None 값 제외
- 기본 사용
refresh( ) 함수
- 데이터베이스와 객체를 동기화한다
데이터 삽입 과정
async def create(self, db: AsyncSession, *, obj_in: Any) -> ModelType: # 1. Pydantic → SQLAlchemy 객체 생성 db_obj = self.model(**obj_in.model_dump()) print(db_obj.id) # None (아직 DB에 저장 전) # 2. 세션에 추가 db.add(db_obj) print(db_obj.id) # 여전히 None # 3. DB에 커밋 (실제 저장) await db.commit() print(db_obj.id) # 여전히 None! # 4. 객체 새로고침 (DB에서 최신 데이터 가져옴) await db.refresh(db_obj) print(db_obj.id) # 1 (DB에서 자동 생성된 ID) return db_obj- Pydantic의
model_dump()함수로 딕셔너리 전환 - 딕셔너리를
self.model(해당 모델)의 init 함수로 넘겨 모델 객체 생성 db.add()로 DB 세션에 저장db.commit()을 통해 실제 DB에 저장db.refresh()를 통해 해당 객체의 정보를 DB에서 갱신- Database의 종류가 무엇이든, SQLAlchemy는 DB 저장 후 ID를 가져온다
- Pydantic의
setattr( ) 함수
class User: def __init__(self): self.name = "" self.age = 0 user = User() # 일반적인 속성 설정 user.name = "김철수" user.age = 25 # setattr()을 사용한 동적 속성 설정 setattr(user, "name", "김철수") # user.name = "김철수"와 같음 setattr(user, "age", 25) # user.age = 25와 같음getattr()함수는 동적으로 객체의 속성의 값을 가져오는 것이다setattr()은 동적으로 객체의 속성의 값을 설정하는 것이다
app/common/crud/base.py (수정)
from typing import Type, TypeVar, Any, Optional, List
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
class CRUDBase:
def __init__(self, model: Type[ModelType]):
self.model = model
async def get(self, db: AsyncSession, obj_id: Any) -> Optional[ModelType]:
return await db.get(self.model, obj_id)
async def get_by_field(self, db: AsyncSession, field_name: str, value: Any) -> Optional[ModelType]:
stmt = select(self.model).where(getattr(self.model, field_name) == value)
result = await db.execute(stmt)
return result.scalars().one_or_none()
async def get_by_fields(self, db: AsyncSession, **kwargs: Any) -> Optional[ModelType]:
conditions = [getattr(self.model, key) == value for key, value in kwargs.items()]
stmt = select(self.model).where(*conditions)
result = await db.execute(stmt)
return result.scalars().one_or_none()
async def get_multi(self, db: AsyncSession, *, skip: int = 0, limit: int = 100) -> List[ModelType]:
stmt = select(self.model).offset(skip).limit(limit)
result = await db.execute(stmt)
return result.scalars().all()
async def get_multi_by_fields(self, db: AsyncSession, *, skip: int = 0, limit: int = 100, **kwargs: Any) -> List[ModelType]:
conditions = [getattr(self.model, key) == value for key, value in kwargs.items()]
stmt = select(self.model).where(*conditions).offset(skip).limit(limit)
result = await db.execute(stmt)
return result.scalars().all()
async def get_all(self, db: AsyncSession) -> List[ModelType]:
stmt = select(self.model)
results = await db.execute(stmt)
return results.scalars().all()
async def count(self, db: AsyncSession) -> int:
stmt = select(func.count()).select_from(self.model)
result = await db.execute(stmt)
return result.scalar_one()
async def count_by_fields(self, db: AsyncSession, **kwargs: Any) -> int:
conditions = [getattr(self.model, key) == value for key, value in kwargs.items()]
stmt = select(func.count()).select_from(self.model).where(*conditions)
result = await db.execute(stmt)
return result.scalar_one()
async def create(self, db: AsyncSession, *, obj_in: Any) -> ModelType:
db_obj = self.model(**obj_in.model_dump())
db.add(db_obj)
return db_obj
async def update(self, db: AsyncSession, db_obj: ModelType, obj_in: Any) -> ModelType:
update_data = obj_in.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(db_obj, key, value)
db.add(db_obj)
return db_obj
async def remove(self, db: AsyncSession, *, db_obj: ModelType) -> ModelType:
await db.delete(db_obj)
return db_obj
- FastAPI는 트랜잭션을 DB 세션으로 관리한다
- Spring Boot 처럼 관리되지 않고 직접 관리해야 한다.
문제점
- 현재의
CRUDBase클래스는 트랜잭션이 각각 설정되어 있는 것이 문제이다. - 비즈니스 로직 하나에 여러 모델을 생성하거나 수정, 삭제할 수 있는데 전체의 비즈니스 로직이 원자성을 가져야 한다
- 하지만, 이 경우 개별적으로 트랜잭션이 종료되므로 서비스 레벨에서 트랜잭션을 관리해야 한다
- 현재의
기존의 create, update, remove 함수의
commit()과refresh()를 제거단, create 함수의 결과 객체에는 ID가 존재하지 않는다는 것을 명심해야 한다.
- ID가 필요한 경우 Service Layer에서 중간
db.flush()를 해야 한다.
- ID가 필요한 경우 Service Layer에서 중간
app/common/service/base.py
from typing import TypeVar, Any, Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from app.common.crud.base import CRUDBase, ModelType
ServiceType = TypeVar("ServiceType", bound="BaseService")
class BaseService:
def __init__(self, crud: CRUDBase):
self.crud = crud
async def get_by_id(self, db: AsyncSession, obj_id: Any) -> Optional[ModelType]:
return await self.crud.get(db, obj_id)
async def get_by_field(self, db: AsyncSession, field_name: str, value: Any) -> Optional[ModelType]:
return await self.crud.get_by_field(db, field_name, value)
async def get_by_fields(self, db: AsyncSession, **kwargs: Any) -> Optional[ModelType]:
return await self.crud.get_by_fields(db, **kwargs)
async def get_multi(self, db: AsyncSession, *, skip: int = 0, limit: int = 100) -> List[ModelType]:
return await self.crud.get_multi(db, skip=skip, limit=limit)
async def get_multi_by_fields(self, db: AsyncSession, *, skip: int = 0, limit: int = 100, **kwargs: Any) -> List[ModelType]:
return await self.crud.get_multi_by_fields(db, skip=skip, limit=limit, **kwargs)
async def get_all(self, db: AsyncSession) -> List[ModelType]:
return await self.crud.get_all(db)
async def count(self, db: AsyncSession) -> int:
return await self.crud.count(db)
async def count_by_fields(self, db: AsyncSession, **kwargs: Any) -> int:
return await self.crud.count_by_fields(db, **kwargs)
async def create(self, db: AsyncSession, *, obj_in: Any) -> ModelType:
db_obj = await self.crud.create(db, obj_in=obj_in)
await db.commit()
await db.refresh(db_obj)
return db_obj
async def update(self, db: AsyncSession, db_obj: ModelType, obj_in: Any) -> ModelType:
updated_obj = await self.crud.update(db, db_obj=db_obj, obj_in=obj_in)
await db.commit()
await db.refresh(updated_obj)
return updated_obj
async def remove(self, db: AsyncSession, *, db_obj: ModelType) -> ModelType:
removed_obj = await self.crud.remove(db, db_obj=db_obj)
await db.commit()
return removed_obj
간단한 비즈니스 로직을 처리할 수 있는 Base 서비스이다.
async def create(self, db: AsyncSession, *, obj_in: Any) -> ModelType: db_obj = await self.crud.create(db, obj_in=obj_in) await db.commit() await db.refresh(db_obj) return db_obj async def update(self, db: AsyncSession, db_obj: ModelType, obj_in: Any) -> ModelType: updated_obj = await self.crud.update(db, db_obj=db_obj, obj_in=obj_in) await db.commit() await db.refresh(updated_obj) return updated_obj async def remove(self, db: AsyncSession, *, db_obj: ModelType) -> ModelType: removed_obj = await self.crud.remove(db, db_obj=db_obj) await db.commit() return removed_obj- BaseCRUD에서 처리하지 못한 create, update, remove 함수의 트랜잭션 처리를 Service에서 처리한다
모델 특화 서비스는 이 BaseService를 상속 받아 구현한다
app/core/deps.py
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from app.domains.country.services import CountryService
from app.domains.country.crud import country_crud
def get_country_service(db: AsyncSession = Depends(get_db)) -> CountryService:
return CountryService(country_crud)
- import가 아닌 의존성 주입의 장점
Depends는 의존성 주입 체인을 만든다- 테스트 용이성
- 라우터를 테스트할 때 Mock 객체로 쉽게 전환이 가능하다.
- 객체 생성의 책임 분리
app/domains/country/crud.py
from models import Country
from app.common.crud.base import CRUDBase
class CRUDCountry(CRUDBase):
def __init__(self, model: type[Country]):
super().__init__(model)
country_crud = CRUDCountry(Country)
- CRUDBase 상속 받아서 구성
- 생성자와
super()를 통해 상위 클래스에 필요 인자를 전달한다.
app/domains/country/services.py
from app.common.services.base import BaseService
from crud import CRUDCountry
class CountryService(BaseService):
def __init__(self, crud: CRUDCountry):
super().__init__(crud)
- BaseService 상속 받아서 구성
- BaseService에서
type으로 클래스로 받지 않고 객체로 받고 있기 때문에 객체로 지정
app/domains/country/schemas.py
from pydantic import BaseModel, ConfigDict
from typing import Optional
class CountryBase(BaseModel):
country_code: str
model_config = ConfigDict(from_attributes=True)
class CountryCreate(CountryBase):
name: str
class CountryUpdate(CountryCreate):
name: Optional[str] = None
country_code: Optional[str] = None
class CountryListItem(CountryBase):
id: int
class CountryResponse(CountryCreate):
id: int
created_at: datetime
updated_at: datetime
CountryBase
country_code필드만 정의한 이유- 가장 기본적인 속성이며, 여러 DTO에서 공통으로 사용되기 때문
name필드의 경우 필요 없는 경우가 있다
이렇게 상위 클래스를 만드는 이유
- 공통 필드 관리
model_config설정 상속
CountryUpdate
- Pydantic은 상속받은 클래스에서 필드를 오버라이딩하는 것을 허용
app/common/schemas/base.py
from pydantic import BaseModel
from typing import Generic, TypeVar, Optional
T = TypeVar("T")
class BaseResponse(BaseModel, Generic[T]):
code: str
message: str
data: Optional[T] = None
- 공통 응답 설정
- Router에서
response_model로 설정한 Pydantic 모델과 Router 메서드의 반환값이 다른 경우 서버 오류를 반환 - 따라서, 상태코드, 메시지 등과 함께 응답하기 위해선 공통 응답을 설정해야 한다.
- Router에서
app/domains/country/router.py
from fastapi import APIRouter
from fastapi.params import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List
from app.core.database import get_db
from app.core.deps import get_country_service
from app.common.schemas.base import BaseResponse
from services import CountryService
from schemas import CountryListItem
router = APIRouter(prefix="/countries")
@router.get("/", response_model=BaseResponse[List[CountryListItem]])
async def get_countries(
country_service: CountryService = Depends(get_country_service),
db: AsyncSession = Depends(get_db)
):
countries = await country_service.get_all(db=db)
response = [CountryListItem.model_validate(country) for country in countries]
return {
"code": "200",
"message": "OK",
"data": response
}
model_validate( ) 함수
- Pydantic BaseModel 클래스의 핵심 정적 메서드
- 다양한 형태의 입력 데이터를 Pydantic 모델로 변환하고 유효성을 검사
- Pydantic 모델에
from_attributes=True이 필수적으로 설정되어 있어야 한다. - 작동 원리
- 입력 데이터 수용
- 딕셔너리, JSON, SQLAlchemy 모델 객체 등 다양한 형태
from_attributes=True활성화- 모델의 필드들을 속성 이름으로 매핑하고 Pydantic 객체를 생성
- 속성 기반 변환 후 유효성 검사
- Pydantic 모델에 설정된 타입 힌트 참고
- 입력 데이터 수용
리스트 컴프리헨션(List Comprehension) 문법
[변환된_값 for 요소 in 리스트_또는_이터러블]- 파이썬에서 반복문과 조건문을 사용하여 리스트를 한 줄로 간결하게 생성하는 문법
리스트 컴프리헨션 예시
countries_data = [] for country in countries: converted_country = CountryListItem.model_validate(country) countries_data.append(converted_country)countries_data = [CountryListItem.model_validate(country) for country in countries]
Depend( )
- FastAPI에서는 Depend( )를 통해 의존성을 주입한다
- 이렇게 주입하는 것이 결합도 낮추고 테스트 용이하다.
app/main.py (수정)
from fastapi import FastAPI
from app.core.config import settings
from app.core.middleware import add_middleware
from app.domains.country.router import router as country_router
from app.domains.country.models import Country
from app.domains.holiday.models import Holiday
app = FastAPI(
title="Holiday Keeper API",
description="공휴일 정보 관리 API",
version="1.0.0",
docs_url="/docs",
redoc_url="/redoc"
)
add_middleware(app)
app.include_router(country_router, prefix="/api/v1", tags=["countries"])
@app.get("/")
async def root():
return {"message": "Holiday Keeper API is running!"}
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"service": "holiday-keeper-api",
"version": "1.0.0"
}
@app.get("/test")
async def test():
return {
"message": "테스트 성공",
"data": ["공휴일1", "공휴일2", "공휴일3"]
}
@app.get("/config")
async def get_config():
if not settings.is_dev:
return {"error": "개발 환경 전용 API"}
return {
"project_name": settings.project_name,
"version": settings.version,
"environment": settings.environment,
"server_host": settings.server_host,
"server_port": settings.server_port,
"cors_origins": settings.cors_origins_list,
"log_level": settings.log_level
}
app.include_router()를 통해 router를 등록한다.Model Import 하기
from app.domains.country.models import Country from app.domains.holiday.models import Holiday- Alembic이
autogenerate기능을 위해 사용하는 SQLAlchemy의Base.metadata객체는 모든 모델을 알고 있어야 한다. - SQLAlchemy는 모델들을 자동으로 찾아서 등록하지 않는다
- 따라서 모델이 정의된 Python 모듈을 직접 import해야
Base.metadata에 등록된다.
- Alembic이
Alembic 초기화와 마이그레이션
초기화 (Initialization)
- Alembic을 최초로 설정하는 과정
프로젝트당 딱 한 번만 실행
- 실행 방법
alembic init alembic- 루트 디렉토리에서 명령어 입력
- 결과로 아래의 파일이 생성된다.
alembic/디렉터리alembic.ini
마이그레이션 (Migration)
- 데이터베이스 스키마를 변경하는 과정
- 모델에 필드를 추가하거나, 테이블을 삭제하는 등 실제 데이터베이스 변경 사항이 있을 때마다 하는 과정
alembic revision명령어로 스크립트를 생성alembic upgrade명령어로 그 변경 사항을 적용
alembic.ini (수정)
# database URL. This is consumed by the user-maintained env.py script only.
# other means of configuring database URLs may be customized within the env.py
# file.
sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
- SQLAlchemy 데이터베이스 URL 수정
alembic.ini파일에서sqlalchemy.url =부분을 찾는다.# database URL. This is consumed by the user-maintained env.py script only. # other means of configuring database URLs may be customized within the env.py # file. sqlalchemy.url = mysql+pymysql://root:1234@127.0.0.1:3306/hk_fast_db?charset=utf8mb4 [post_write_hooks]
alembic/env.py (수정)
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from app.core.config import settings
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
from app.common.models.base import Base
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
{'sqlalchemy.url': settings.sync_database_url},
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
Alembic이 SQLAlchemy 모델을 인식하도록 수정해야 한다.
target_metadata = None 설정
target_metadata = None을 SQLAlchemy의 Base 클래스를 임포트로 변경# add your model's MetaData object here # for 'autogenerate' support # from myapp import mymodel # target_metadata = mymodel.Base.metadata from app.common.models.base import Base from app.domains.country.models import Country from app.domains.holiday.models import Holiday target_metadata = Base.metadata # other values from the config, defined by the needs of env.py,- SQLAlchemy의 Base 클래스를 임포트해줘야 인식 가능
run_migrations_online( )
- Alembic이 데이터베이스에 직접 연결하여 마이그레이션 스크립트를 실행
주로 개발 단계나 실제 운영 서버에서 사용
alembic upgrade head명령어를 실행하면 이 함수가 호출된다(우리가 사용할 것)
run_migrations_offline( )
- Alembic이 데이터베이스에 직접 연결하지 않고, SQL 문을 문자열로 생성하여 출력하는 방식
- 데이터베이스 접근이 불가능하거나 제한적인 환경(예: 배포 파이프라인)에서 사용
- 이렇게 생성된 SQL 파일을 나중에 직접 데이터베이스에 적용할 수 있다
alembic upgrade head --sql명령어를 실행하면 이 함수가 호출된다
settings.sync_database_url 사용하기
# 수정 전 def run_migrations_online() -> None: # ... connectable = engine_from_config( config.get_section(config.config_ini_section, {}), prefix='sqlalchemy.', poolclass=pool.NullPool, ) with connectable.connect() as connection: context.configure( connection=connection, target_metadata=target_metadata ) with context.begin_transaction(): context.run_migrations() # 수정 후 from app.core.config import settings def run_migrations_online() -> None: # ... connectable = engine_from_config( {'sqlalchemy.url': settings.sync_database_url}, prefix="sqlalchemy.", poolclass=pool.NullPool, ) with connectable.connect() as connection: context.configure( connection=connection, target_metadata=target_metadata ) with context.begin_transaction(): context.run_migrations()- 보안 강화를 위해서
alembic.ini에서sqlalchemy.url을 읽어와서 데이터베이스에 연결하지 않고 Pydantic BaseSettings가 관리하는settings.sync_database_url에서 직접 가져온다. - 보안 강화
sqlalchemy.url을.ini파일에 직접 노출하지 않고,.env파일을 통해 환경 변수로 관리
- 일관성 있게 환경 변수를 통해 관리한다
- 보안 강화를 위해서
alembic.ini 파일 수정하기
# database URL. This is consumed by the user-maintained env.py script only. # other means of configuring database URLs may be customized within the env.py # file. # sqlalchemy.url = driver://user:pass@localhost/dbname- 이제
env.py에서settings.sync_database_url을 직접 사용하도록 변경- Alembic은 이제
alembic.ini에 설정된sqlalchemy.url값을 사용하지 않는다
- Alembic은 이제
- 주석 처리하거나 제거
- 이제
마이그레이션 스크립트 적용하기
Alembic에게 마이그레이션 파일 요청
alembic revision --autogenerate -m "Create Country Model table"-autogenerateenv.py의target_metadata를 보고 현재 데이터베이스 스키마와 비교한다
-m "..."- 마이그레이션에 대한 설명을 추가
- 결과
alembic/versions/폴더에 테이블 생성 코드가 담긴 파이썬 파일이 생성
마이그레이션 실행
alembic upgrade headupgrade head- 가장 최근에 생성된 마이그레이션 스크립트까지 모두 실행
스키마 변경 시 마이그레이션
- 모델 변경 후 아래 과정을 적용해야 한다
마이그레이션 스크립트 생성
alembic revision --autogenerate -m "Add population to Country model"--autogenerate로 DB의 스키마와 SQLAlchemy 모델을 비교하여 변경 분석- 결과로
alembic/versions/디렉토리에 새로운 버전 파일이 생성된다.
마이그레이션 적용하기
alembic upgrade head- 아직 데이터베이스에 적용되지 않은 마이그레이션 스크립트 중 가장 최신 버전(
head)까지 모두 실행
- 아직 데이터베이스에 적용되지 않은 마이그레이션 스크립트 중 가장 최신 버전(
app/domains/country/router.py (수정)
from fastapi import APIRouter
from fastapi.params import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List
from app.core.database import get_db
from app.core.deps import get_country_service
from app.common.schemas.base import BaseResponse
from app.domains.country.services import CountryService
from app.domains.country.schemas import CountryListItem, CountryResponse, CountryUpdate, CountryCreate
router = APIRouter(prefix="/countries")
@router.get("/", response_model=BaseResponse[List[CountryListItem]])
async def get_countries(
country_service: CountryService = Depends(get_country_service),
db: AsyncSession = Depends(get_db)
):
countries = await country_service.get_all(db=db)
response = [CountryListItem.model_validate(country) for country in countries]
return {
"code": "200",
"message": "목록 조회 성공",
"data": response
}
@router.get("/{country_id}", response_model=BaseResponse[CountryResponse])
async def get_country(
country_id: int,
country_service: CountryService = Depends(get_country_service),
db: AsyncSession = Depends(get_db)
):
country = await country_service.get_by_id(db, country_id)
resource = CountryResponse.model_validate(country)
return {
"code": "200",
"message": f"{country_id}번 국가 조회 성공",
"data": resource
}
@router.post("/")
async def create_country(
country_create_dto: CountryCreate,
country_service: CountryService = Depends(get_country_service),
db: AsyncSession = Depends(get_db)
):
await country_service.create(db, obj_in=country_create_dto)
return {
"code": "201",
"message": "새로운 국가 생성 성공",
"data": None
}
@router.put("/{country_id}")
async def update_country(
country_id: int,
country_update_dto: CountryUpdate,
country_service: CountryService = Depends(get_country_service),
db: AsyncSession = Depends(get_db)
):
country = await country_service.get_by_id(db, country_id)
await country_service.update(db, db_obj=country, obj_in=country_update_dto)
return {
"code": "204",
"message": f"{country_id}번 국가 수정 성공",
"data": None
}
@router.delete("/{country_id}")
async def delete_country(
country_id: int,
country_service: CountryService = Depends(get_country_service),
db: AsyncSession = Depends(get_db)
):
country = await country_service.get_by_id(db, country_id)
await country_service.remove(db, db_obj=country)
return {
"code": "204",
"message": f"{country_id}번 국가 제거 성공",
"data": None
}
2.4. 공통 응답 및 에러 처리
app/common/schemas/base.py (수정)
from pydantic import BaseModel, Field
from typing import Generic, TypeVar, Optional
from datetime import datetime
T = TypeVar("T")
class BaseResponse(BaseModel, Generic[T]):
code: str = Field("200", description="응답 코드")
message: str = Field("OK", description="응답 메시지")
data: Optional[T] = Field(None, description="응답 데이터")
class ErrorResponse(BaseModel):
timestamp: datetime = Field(default_factory=datetime.now, description="에러 발생 일시")
code: str = Field("500", description="에러 코드")
message: str = Field("Internal Server Error", description="에러 메시지")
data: None = Field(None, description="응답 데이터")
details: Optional[Dict[str, Any]] = Field(None, description="에러 상세 정보")
BaseResponse
T = TypeVar("T") class BaseResponse(BaseModel, Generic[T]): code: str = Field("200", description="응답 코드") message: str = Field("Success", description="응답 메시지") data: Optional[T] = Field(None, description="응답 데이터")- 정상적인 API 응답을 위한 기본 스키마
data- 제네릭 타입
T로 정의되어 있어, 다양한 데이터 모델을 담을 수 있다
- 제네릭 타입
Field( ) 함수 사용
Field()함수로 메타데이터 추가 (Swagger API)- 조건을 설정하여 더 구체적으로 유효성 검사 가능
- Pydantic BaseModel로 타입 검사는 기본적으로 수행한다
Field(...)로 기본값 없이 필수 필드로 지정
ErrorResponse
class ErrorResponse(BaseModel): timestamp: datetime = Field(default_factory=datetime.now, description="에러 발생 일시") code: str = Field("500", description="에러 코드") message: str = Field("Internal Server Error", description="에러 메시지") data: None = Field(None, description="응답 데이터") details: Optional[Dict[str, Any]] = Field(None, description="에러 상세 정보")- 에러 발생 시 API 응답을 위한 스키마
data필드는 항상Nonedefault_factoryField함수에서 사용- 인수가 없는 함수를 받아서 모델이 생성될 때마다 해당 함수를 호출하여 기본값 생성
app/common/schemas/pagination.py
from pydantic import BaseModel, Field
from typing import List, TypeVar, Generic
T = TypeVar("T")
class PaginationMeta(BaseModel):
total_elements: int = Field(..., description="전체 항목 수")
page: int = Field(..., description="현재 페이지 번호")
page_size: int = Field(..., description="페이지 당 항목 수")
has_next: bool = Field(..., description="다음 페이지 존재 여부")
has_prev: bool = Field(..., description="이전 페이지 존재 여부")
class PageResponse(BaseModel, Generic[T]):
data: List[T] = Field(..., description="응답 데이터")
meta: PaginationMeta = Field(..., description="페이지 메타데이터")
PaginationMeta
class PaginationMeta(BaseModel): total_elements: int = Field(..., description="전체 항목 수") page: int = Field(..., description="현재 페이지 번호") page_size: int = Field(..., description="페이지 당 항목 수") has_next: bool = Field(..., description="다음 페이지 존재 여부") has_prev: bool = Field(..., description="이전 페이지 존재 여부")- 페이지네이션 메타데이터를 위한 스키마
data- 제네릭 타입
T로 정의되어 있어, 다양한 데이터 모델을 담을 수 있다
- 제네릭 타입
PageResponse
T = TypeVar("T") class PageResponse(BaseModel, Generic[T]): data: List[T] = Field(..., description="응답 데이터") meta: PaginationMeta = Field(..., description="페이지 메타데이터")- 페이지네이션 응답 스키마
data- 제네릭 타입
T로 정의되어 있어, 다양한 데이터 모델을 담을 수 있다
- 제네릭 타입
app/domains/country/schemas.py (수정)
from datetime import datetime
from pydantic import BaseModel, ConfigDict, Field
from typing import Optional
class CountryBase(BaseModel):
country_code: str = Field(..., description="국가 코드", examples=["KR", "US"])
model_config = ConfigDict(from_attributes=True)
class CountryCreate(CountryBase):
name: str = Field(..., description="국가 영문명", examples=["Australia"])
class CountryUpdate(CountryCreate):
name: Optional[str] = Field(None, description="국가 영문명", examples=["Australia"])
country_code: Optional[str] = Field(None, description="국가 코드", examples=["KR", "US"])
class CountryListItem(CountryBase):
id: int = Field(..., description="ID", examples=[4])
class CountryResponse(CountryCreate):
id: int = Field(..., description="ID", examples=[4])
created_at: datetime = Field(..., description="생성 일시", examples=[datetime.now()])
updated_at: datetime = Field(..., description="수정 일시", examples=[datetime.now()])
- Swagger UI 메타데이터 추가를 위해
Field()함수를 활용
app/core/exceptions.py
from __future__ import annotations
from fastapi import status
from dataclasses import dataclass, field
from typing import Any, Dict
from enum import Enum
@dataclass(frozen=True)
class ErrorCode:
code: str
http_status: int
message: str
class CustomException(Exception):
def __init__(self, error_type: ErrorType, message: str = None, details: Dict[str, Any] = None):
self.error_code = error_type.value
self.details = details if details else {}
self.message = message or self.error_code.message
super().__init__(self.message)
class ErrorType(Enum):
INTERNAL_SERVER_ERROR = ErrorCode(
code="500", http_status=status.HTTP_500_INTERNAL_SERVER_ERROR, message="서버에 오류가 발생했습니다"
)
NOT_FOUND = ErrorCode(
code="404", http_status=status.HTTP_404_NOT_FOUND, message="요청하신 리소스를 찾을 수 없습니다"
)
BAD_REQUEST = ErrorCode(
code="400", http_status=status.HTTP_400_BAD_REQUEST, message="Bad Request"
)
VALIDATION_ERROR = ErrorCode(
code="422", http_status=status.HTTP_422_UNPROCESSABLE_ENTITY, message="Validation Error"
)
class ResourceNotFoundException(CustomException):
def __init__(self, *, resource_type: str, resource_id: Any):
message = f"ID {resource_id}번 {resource_type}을 찾을 수 없습니다."
details = {"resource_id": resource_id, "resource_type": resource_type}
super().__init__(ErrorType.NOT_FOUND, message=message, details=details)
from __future__ import annotations
- Python 3.7+의 기능 :
from __future__ import annotations - 타입 힌트의 순환 참조 문제 해결
- 모든 타입 힌트를 문자열로 저장
- 런타임 성능 향상
- 지연 평가 (lazy evaluation)
- 타입 힌트가 실제로 필요할 때만 평가
- import 시간 단축
- 타입 체킹용 모듈들을 즉시 로드하지 않는다
- Python 3.7+의 기능 :
ErrorCode
@dataclass(frozen=True) class ErrorCode: code: str http_status: int message: str- 표준화된 에러 코드를 정의하는 클래스
code: 클라이언트에 노출될 고유 에러 코드http_status: HTTP 상태 코드message: 기본 에러 메시지- @dataclass
- Python 3.7부터 도입된 데코레이터
- 데이터를 캡슐화하고 속성을 가진 객체를 쉽게 만들 때 사용
- 데이터를 저장하는 클래스를 쉽게 만들 수 있도록 도와준다
__init__,__repr__,__eq__등 자동 생성
frozen=True- 불변(immutable) 객체를 만드는 기능
- 생성 후 필드 값 변경 불가능
CustomException 클래스
class CustomException(Exception): def __init__(self, error_type: ErrorType, message: str = None, details: Dict[str, Any] = None): self.error_code = error_type.value self.details = details if details else {} self.message = message or self.error_code.message super().__init__(self.message) class CustomException(Exception): def __init__(self, error_code: ErrorCode, message: str = None): super().__init__(message or error_code.message) self.error_code = error_code if message: self.error_code.message = message- 커스텀 에러 처리를 위한 예외 클래스
- FastAPI가 제공하는 기본 HTTPException을 활용해도 되지만 상세한 설정이 가능
error_typeEnum을 전달하여 예외 발생 시 상세 정보를 전달CustomException 객체 생성 시 특정 메세지나 Detail을 포함한다면, 자체 필드로 그 값을 받는다
def __init__(self, error_type: ErrorType, message: str = None, details: Dict[str, Any] = None): self.error_code = error_type.value self.details = details if details else {} self.message = message or self.error_code.message- details 예시
"details": {"field": "email", "reason": "invalid_format"}"details": {"resource": "/users/123", "required_permission": "admin"}"details": {"external_service": "payment_gateway", "error_code": "PG-503"}
- 커스텀 에러 처리를 위한 예외 클래스
ErrorType 클래스
class ErrorType(Enum): INTERNAL_SERVER_ERROR = ErrorCode( code="500", http_status=status.HTTP_500_INTERNAL_SERVER_ERROR, message="서버에 오류가 발생했습니다" ) NOT_FOUND = ErrorCode( code="404", http_status=status.HTTP_404_NOT_FOUND, message="요청하신 리소스를 찾을 수 없습니다" ) BAD_REQUEST = ErrorCode( code="400", http_status=status.HTTP_400_BAD_REQUEST, message="Bad Request" ) VALIDATION_ERROR = ErrorCode( code="422", http_status=status.HTTP_422_UNPROCESSABLE_ENTITY, message="Validation Error" )- Enum을 활용하여 실수하지 않도록 설정
ResourceNotFoundException 클래스
class ResourceNotFoundException(CustomException): def __init__(self, *, resource_type: str, resource_id: Any): message = f"ID {resource_id}번 {resource_type}을 찾을 수 없습니다." details = {"resource_id": resource_id, "resource_type": resource_type} super().__init__(ErrorType.NOT_FOUND, message=message, details=details)- 쉽게 에러 처리를 할 수 있도록 에러 클래스를 미리 만들어 놓고 처리한다.
app/domains/country/exceptions.py
from app.core.exceptions import CustomException, ErrorType
from typing import Any
class CountryNotFoundException(CustomException):
def __init__(self, *, resource_id: Any):
message = f"ID {resource_id}번 국가를 찾을 수 없습니다."
details = {"resource_id": resource_id, "resource_type": "Country"}
super().__init__(ErrorType.NOT_FOUND, message=message, details=details)
- Country 도메인 특화 에러 모음
app/core/exception_handlers.py
import logging
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from app.core.exceptions import CustomException, ErrorType
from app.common.schemas.base import ErrorResponse
logger = logging.getLogger(__name__)
def add_exception_handler(app: FastAPI):
@app.exception_handler(CustomException)
async def custom_exception_handler(request: Request, exc: CustomException):
logger.error(f"CustomException: {exc.message} - Code: {exc.error_code.code}")
response = ErrorResponse(
code=exc.error_code.code,
message=exc.message,
details = exc.details if exc.details else None
)
return JSONResponse(
status_code=exc.error_code.http_status,
content=response.model_dump(mode="json")
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
logger.error(f"Validation error: {exc.errors()}")
error_messages = []
for error in exc.errors():
field = " -> ".join(str(x) for x in error["loc"][1:])
message = error["msg"]
error_messages.append(f"{field}: {message}")
error_code = ErrorType.VALIDATION_ERROR.value
response = ErrorResponse(
message="; ".join(error_messages),
code=error_code.code
)
return JSONResponse(
status_code=error_code.http_status,
content=response.model_dump(mode="json")
)
@app.exception_handler(404)
async def not_found_handler(request: Request, exc):
error_code = ErrorType.NOT_FOUND.value
response = ErrorResponse(
message=error_code.message,
code=error_code.code
)
return JSONResponse(
status_code=error_code.http_status,
content=response.model_dump(mode="json")
)
@app.exception_handler(500)
async def internal_error_handler(request: Request, exc):
logger.error(f"Internal server error: {str(exc)}")
error_code = ErrorType.INTERNAL_SERVER_ERROR.value
response = ErrorResponse(
message=error_code.message,
code=error_code.code
)
return JSONResponse(
status_code=error_code.http_status,
content=response.model_dump(mode="json")
)
JSONResponse의 필요성
- FastAPI는 기본적으로 Pydantic 모델이나 Python 딕셔너리를 반환하면 자동으로 JSON 응답으로 변환한다
- 하지만 예외가 발생한 상황은 개발자가 직접 HTTP 상태 코드와 응답 바디를 명시적으로 제어해야 한다
model_dump( )
- Pydantic V2에서 Pydantic 모델 객체를 파이썬 딕셔너리로 변환하는 메서드
- 예외 발생 상황이므로 Pydantic 모델 객체를 HTTP 응답 본문으로 사용하기 위해 JSON 형식으로 변환 가능한 딕셔너리로 만들어야 한다
timestamp의 경우 직렬화 문제가 있을 수 있기 때문에model_dump(mode=”json”)으로 설정
custom_exception_handler( )
- CustomException 클래스의 자체
message와details필드- 이를 통해
details필드를 통해 동적으로 예외 상황의 상세 정보를 추가할 수 있다 - 유효성 검사 실패 시
{"field": "email", "reason": "invalid_format"}
- 외부 API 호출 실패 시 해당 서비스의 에러 코드
{"external_api": "payment_gateway", "error_code": "PG-400"}
- 사용자를 위한 것이 아닌 오류 분석 시스템 구축 시 그것을 위한 메시지
- 이를 통해
- CustomException 클래스의 자체
validation_exception_handler( )
FastAPI가 요청을 받으면, Pydantic 모델의 규칙(ex :
max_length,gt)에 따라 유효성 검사 수행유효성 검사에 실패 시 RequestValidationError라는 특별한 예외 객체가 자동으로 발생
예외 객체에는 오류 리스트가 담겨 있다.
[ { "loc": ["body", "name"], "msg": "Input should be a valid string", "type": "string_type" }, { "loc": ["body", "price"], "msg": "Input should be greater than 0", "type": "greater_than" } ]
validation_exception_handler함수는 이 리스트를 순회loc(오류 위치)와msg(오류 메시지)를 뽑아내서 가공 후 클라이언트에 전달name: Input should be a valid string; price: Input should be greater than 0
app/main.py
from app.core.exception_handlers import add_exception_handler
from app.core.exceptions import CustomException, ErrorType
add_exception_handler(app)
@app.get("/test/error")
async def test_error():
raise CustomException(ErrorType.NOT_FOUND)
add_exception_handler(app)로 전역 에러 처리를 등록하면 전역 에러 처리가 가능해진다.예외 처리 테스트 API
@app.get("/test/error") async def test_error(): raise CustomException(ErrorType.NOT_FOUND)
app/common/crud/base.py (수정)
from typing import Type, TypeVar, Any, Optional, List
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
class CRUDBase:
def __init__(self, model: Type[ModelType]):
self.model = model
async def get(self, db: AsyncSession, obj_id: Any) -> Optional[ModelType]:
return await db.get(self.model, obj_id)
async def get_multi(self, db: AsyncSession, *, skip: int = 0, limit: int = 100) -> List[ModelType]:
stmt = select(self.model).offset(skip).limit(limit)
result = await db.execute(stmt)
return result.scalars().all()
async def get_all(self, db: AsyncSession) -> List[ModelType]:
stmt = select(self.model)
results = await db.execute(stmt)
return results.scalars().all()
async def count(self, db: AsyncSession) -> int:
stmt = select(func.count()).select_from(self.model)
result = await db.execute(stmt)
return result.scalar_one()
async def create(self, db: AsyncSession, *, obj_in: Any) -> ModelType:
db_obj = self.model(**obj_in.model_dump())
db.add(db_obj)
return db_obj
async def update(self, db: AsyncSession, db_obj: ModelType, obj_in: Any) -> ModelType:
update_data = obj_in.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(db_obj, key, value)
db.add(db_obj)
return db_obj
async def remove(self, db: AsyncSession, *, db_obj: ModelType) -> ModelType:
await db.delete(db_obj)
return db_obj
- 불필요한 메서드 제거
get_by_field,get_by_fields,get_multi_by_fields,count_by_fields
app/common/services/base.py (수정)
from typing import TypeVar, Any, Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import IntegrityError
from app.common.crud.base import CRUDBase, ModelType
from app.core.exceptions import ResourceNotFoundException, ResourceConflictException
CRUDType = TypeVar("CRUDType", bound=CRUDBase)
class BaseService:
def __init__(self, crud: CRUDType):
self.crud : CRUDType = crud
async def get_by_id(self, db: AsyncSession, obj_id: Any) -> ModelType:
obj = await self.crud.get(db, obj_id)
if not obj:
raise ResourceNotFoundException(
resource_type=self.crud.model.__name__,
resource_id=obj_id
)
return obj
async def get_multi(self, db: AsyncSession, *, skip: int = 0, limit: int = 100) -> List[ModelType]:
return await self.crud.get_multi(db, skip=skip, limit=limit)
async def get_all(self, db: AsyncSession) -> List[ModelType]:
return await self.crud.get_all(db)
async def count(self, db: AsyncSession) -> int:
return await self.crud.count(db)
async def create(self, db: AsyncSession, *, obj_in: Any) -> ModelType:
try:
db_obj = await self.crud.create(db, obj_in=obj_in)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError:
await db.rollback()
raise ResourceConflictException(resource_type=self.crud.model.__name__, operation="생성")
async def update(self, db: AsyncSession, db_obj: ModelType, obj_in: Any) -> ModelType:
try:
updated_obj = await self.crud.update(db, db_obj=db_obj, obj_in=obj_in)
await db.commit()
await db.refresh(updated_obj)
return updated_obj
except IntegrityError:
await db.rollback()
raise ResourceConflictException(resource_type=self.crud.model.__name__, operation="수정")
async def remove(self, db: AsyncSession, *, db_obj: ModelType) -> ModelType:
removed_obj = await self.crud.remove(db, db_obj=db_obj)
await db.commit()
return removed_obj
제네릭 타입 오류 수정
CRUDType = TypeVar("CRUDType", bound=CRUDBase) class BaseService: def __init__(self, crud: CRUDType): self.crud : CRUDType = crud- 이렇게 해야지 IDE가 정확한 타입을 인식할 수 있다
불필요한 메서드 제거
get_by_field,get_by_fields,get_multi_by_fields,count_by_fields
get_by_id 함수
수정 전
async def get_by_id(self, db: AsyncSession, obj_id: Any) -> Optional[ModelType]: return await self.crud.get(db, obj_id)수정 후
async def get_by_id(self, db: AsyncSession, obj_id: Any) -> ModelType: obj = await self.crud.get(db, obj_id) if not obj: raise ResourceNotFoundException( resource_type=self.crud.model.__name__, resource_id=obj_id ) return obj- 반환 타입에서
Optional을 제거하여 객체를 반드시 반환하거나 예외를 발생시키도록 변경 ResourceNotFoundException에러 처리self.crud.model.__name__CRUDBase인스턴스의model속성으로 전달받은 모델 클래스의 클래스 이름을 문자열로 받는다
- 반환 타입에서
app/core/exceptions.py (수정)
class ErrorType(Enum):
INTERNAL_SERVER_ERROR = ErrorCode(
code="500", http_status=status.HTTP_500_INTERNAL_SERVER_ERROR, message="서버에 오류가 발생했습니다"
)
NOT_FOUND = ErrorCode(
code="404", http_status=status.HTTP_404_NOT_FOUND, message="요청하신 리소스를 찾을 수 없습니다"
)
BAD_REQUEST = ErrorCode(
code="400", http_status=status.HTTP_400_BAD_REQUEST, message="Bad Request"
)
VALIDATION_ERROR = ErrorCode(
code="422", http_status=status.HTTP_422_UNPROCESSABLE_ENTITY, message="Validation Error"
)
CONFLICT = ErrorCode(
code="409", http_status=status.HTTP_409_CONFLICT, message="리소스 충돌이 발생했습니다"
)
class ResourceConflictException(CustomException):
def __init__(self, *, resource_type: str, operation: str = "처리"):
message = f"{resource_type} {operation} 중 중복 데이터가 발견되었습니다."
details = {"resource_type": resource_type, "operation": operation}
super().__init__(ErrorType.CONFLICT, message=message, details=details)
ErrorType에 CONFLICT 타입 추가
CONFLICT = ErrorCode( code="409", http_status=status.HTTP_409_CONFLICT, message="리소스 충돌이 발생했습니다" )ResourceConflictException 공용 예외 클래스 추가
class ResourceConflictException(CustomException): def __init__(self, *, resource_type: str, operation: str = "처리"): message = f"{resource_type} {operation} 중 중복 데이터가 발견되었습니다." details = {"resource_type": resource_type, "operation": operation} super().__init__(ErrorType.CONFLICT, message=message, details=details)
app/domains/country/crud.py (수정)
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Optional
from sqlalchemy import select
from app.domains.country.models import Country
from app.common.crud.base import CRUDBase
class CRUDCountry(CRUDBase):
def __init__(self, model: type[Country]):
super().__init__(model)
async def get_by_country_code(self, db: AsyncSession, country_code: str) -> Optional[Country]:
stmt = select(self.model).where(self.model.country_code == country_code)
result = await db.execute(stmt)
return result.scalars().one_or_none()
country_crud = CRUDCountry(Country)
get_by_country_code 함수 생성
async def get_by_country_code(self, db: AsyncSession, country_code: str) -> Optional[Country]: stmt = select(self.model).where(self.model.country_code == country_code) result = await db.execute(stmt) return result.scalars().one_or_none()Optional로 예외 처리가 가능하게 반환한다.
app/domains/country/exceptions.py (수정)
from app.core.exceptions import CustomException, ErrorType
from typing import Any
class CountryNotFoundException(CustomException):
def __init__(self, *, field_name: str, value: Any):
message = f"{field_name} '{value}'에 해당하는 국가를 찾을 수 없습니다."
details = {"field_name": field_name, "value": value, "resource_type": "Country"}
super().__init__(ErrorType.NOT_FOUND, message=message, details=details)
class CountryConflictException(CustomException):
def __init__(self, *, operation: str = "처리"):
message = f"Country {operation} 중 중복 데이터가 발견되었습니다."
details = {"resource_type": "Country", "operation": operation}
super().__init__(ErrorType.CONFLICT, message=message, details=details)
CountryConflictException 예외 클래스 추가
class CountryConflictException(CustomException): def __init__(self, *, operation: str = "처리"): message = f"Country {operation} 중 중복 데이터가 발견되었습니다." details = {"resource_type": "Country", "operation": operation} super().__init__(ErrorType.CONFLICT, message=message, details=details)
app/domains/country/services.py (수정)
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import IntegrityError
from app.common.services.base import BaseService
from app.domains.country.crud import CRUDCountry
from app.domains.country.models import Country
from app.domains.country.schemas import CountryCreate, CountryUpdate
from app.domains.country.exceptions import CountryNotFoundException, CountryConflictException
class CountryService(BaseService):
def __init__(self, crud: CRUDCountry):
super().__init__(crud)
async def get_by_country_code(self, db: AsyncSession, country_code: str) -> Country:
country = await self.crud.get_by_country_code(db, country_code)
if not country:
raise CountryNotFoundException(field_name="국가 코드", value=country_code)
return country
async def create_country(self, db: AsyncSession, *, country_create_dto: CountryCreate) -> Country:
try:
country = await self.crud.create(db, obj_in=country_create_dto)
await db.commit()
await db.refresh(country)
return country
except IntegrityError:
await db.rollback()
raise CountryConflictException(operation="생성")
async def update(self, db: AsyncSession, db_country: Country, country_update_dto: CountryUpdate) -> Country:
try:
updated_country = await self.crud.update(db, db_obj=db_country, obj_in=country_update_dto)
await db.commit()
await db.refresh(updated_country)
return updated_country
except IntegrityError:
await db.rollback()
raise CountryConflictException(operation="수정")
get_by_country_code 함수 생성
async def get_by_country_code(self, db: AsyncSession, country_code: str) -> Country: country = await self.crud.get_by_country_code(db, country_code) if not country: raise CountryNotFoundException(resource_id=country_code) return country- Country 도메인 특화 에러인
CountryNotFoundException에러 처리
- Country 도메인 특화 에러인
create_country 함수 생성
async def create_country(self, db: AsyncSession, *, country_create_dto: CountryCreate) -> Country: try: country = await self.crud.create(db, obj_in=country_create_dto) await db.commit() await db.refresh(country) return country except IntegrityError: await db.rollback() raise CountryConflictException(operation="생성")- Country 도메인 특화 에러인
CountryConflictException에러 처리
- Country 도메인 특화 에러인
update 함수 생성
async def update(self, db: AsyncSession, db_country: Country, country_update_dto: CountryUpdate) -> Country: try: updated_country = await self.crud.update(db, db_obj=db_country, obj_in=country_update_dto) await db.commit() await db.refresh(updated_country) return updated_country except IntegrityError: await db.rollback() raise CountryConflictException(operation="수정")- Country 도메인 특화 에러인
CountryConflictException에러 처리 - 이 경우 공용 서비스의 update 함수를 오버라이딩한다.
- 단, router에서 키워드 인수로 호출 시 인수 이름 주의
- Country 도메인 특화 에러인
app/domains/country/router.py (수정)
@router.get("/code/{country_code}", response_model=BaseResponse[CountryResponse])
async def get_country_by_country_code(
country_code: str,
country_service: CountryService = Depends(get_country_service),
db: AsyncSession = Depends(get_db)
):
country = await country_service.get_by_country_code(db, country_code)
resource = CountryResponse.model_validate(country)
return {
"code": "200",
"message": f"국가 코드 {country_code}의 국가 조회 성공",
"data": resource
}
@router.post("/")
async def create_country(
country_create_dto: CountryCreate,
country_service: CountryService = Depends(get_country_service),
db: AsyncSession = Depends(get_db)
):
await country_service.create(db, obj_in=country_create_dto)
return {
"code": "201",
"message": "새로운 국가 생성 성공",
"data": None
}
@router.put("/{country_id}")
async def update_country(
country_id: int,
country_update_dto: CountryUpdate,
country_service: CountryService = Depends(get_country_service),
db: AsyncSession = Depends(get_db)
):
country = await country_service.get_by_id(db, country_id)
await country_service.update(db, country, country_update_dto)
return {
"code": "204",
"message": f"{country_id}번 국가 수정 성공",
"data": None
}
- 테스트 용 API 추가
- update 함수 ⇒ Country 서비스 특화
update()호출 - create 함수 ⇒ 공용 서비스
create()호출 (작동 테스트)
app/domains/country/models.py (수정)
from sqlalchemy import Column, Integer, String, DateTime
from sqlalchemy.orm import relationship
from app.common.models.base import BaseModel
class Country(BaseModel):
__tablename__ = 'countries'
country_code = Column(String(10), unique=True, nullable=False)
name = Column(String(100), nullable=False)
holidays = relationship(
"Holiday",
back_populates="country",
cascade="all, delete-orphan",
)
- cascade 추가
- 외래 키 제약으로 인해 삭제가 실패 대처
2.5. Celery 스케줄링 및 외부 API 호출 구현하기
app/external/api_client.py
import httpx
from app.core.config import settings
from app.domains.holiday.schemas import (
AvailableCountriesApiResponse,
PublicHolidaysApiResponse,
)
class HolidayAPIClient:
def __init__(self):
self.holidays_url = settings.nager_api_holidays_url
self.countries_url = settings.nager_api_countries_url
async def fetch_countries_data(self) -> list[AvailableCountriesApiResponse]:
async with httpx.AsyncClient() as client:
response = await client.get(self.countries_url)
response.raise_for_status()
return [AvailableCountriesApiResponse(**item) for item in response.json()]
async def fetch_holidays_data(
self, year: int, country_code: str
) -> list[PublicHolidaysApiResponse]:
url = f"{self.holidays_url}/{year}/{country_code}"
async with httpx.AsyncClient() as client:
response = await client.get(url)
response.raise_for_status()
return [PublicHolidaysApiResponse(**item) for item in response.json()]
holiday_api_client = HolidayAPIClient()
httpx.AsyncClient( )
async with httpx.AsyncClient() as client: response = await client.get(self.countries_url) response.raise_for_status() return [AvailableCountriesApiResponse(**item) for item in response.json()]httpx- Python의 HTTP 클라이언트 라이브러리
httpx.AsyncClient()- 비동기 환경에서 HTTP 요청을 보낼 때 사용되는 객체
- 비동기 HTTP 요청을 보낼 수 있는 클라이언트 세션을 생성
- 이 함수는
async with과 함께 사용될 때 가장 효율적이다.- 시작될 때 네트워크 연결이 설정되고, 블록이 끝날 때 연결이 자동으로 해제되어 자원 누수 방지
AsyncClient().get(url)AsyncClient인스턴스의 메서드로, HTTP GET 요청 보내기- 요청이 완료될 때까지 실행을 일시 중단했다가 응답을 받으면 계속 진행
- 다양한 인자
- 인수로 **
url**을 넣어서 GET 요청을 보낸다. params- URL 쿼리 파라미터를 딕셔너리 형태로 전달
headers- 요청 헤더를 딕셔너리 형태로 전달
timeout- 요청 타임아웃 시간을 설정.
follow_redirects- 리다이렉션을 자동으로 따를지 여부를 설정
- 인수로 **
httpx.Response
- 모든
client.method()함수는httpx.Response객체를 반환 - 이 객체에는 요청의 결과에 대한 모든 정보가 담겨 있다.
response.status_code- HTTP 상태 코드 (예 :
200)
- HTTP 상태 코드 (예 :
response.text- 응답 본문을 문자열로 반환
response.json()- 응답 본문을 JSON으로 파싱하여 파이썬 딕셔너리나 리스트로 반환
response.headers- 응답 헤더를 딕셔너리로 반환
- 모든
AsyncClient 인스턴스의 다른 함수들
await client.post(url, data) await client.put(url, data) await client.delete(url) await client.patch(url, data) await client.head(url)raise_for_status 함수
async with httpx.AsyncClient() as client: response = await client.get(self.countries_url) response.raise_for_status() return [AvailableCountriesApiResponse(**item) for item in response.json()]- HTTP 요청의 상태 코드가 성공적이지 않을 때 예외를 발생시키는 편리한 메서드
동작 방식
- 응답 상태 코드(
response.status_code)가 특정 에러인 경우 예외 발생4xx번대 에러 또는5xx번대 에러일 경우- httpx.HTTPStatusError 예외 발생
- 만약 상태 코드가
2xx이거나3xx이라면 그냥 넘어간다
- 응답 상태 코드(
목적
- 명확한 에러 핸들링
- 코드 간결화
- 수동적인 상태 코드 체크 로직을 한 줄로 대체 가능
이 자리에 올 수 있는 다른 함수
if not response.is_successresponse객체의is_success속성을 직접 확인하는 방법
try-except블록raise_for_status()는 예외를 발생시키므로, 이를try-except블록으로 감싸 특정 에러에 대해 분기 가능- 예시 :
httpx.HTTPStatusError를 잡아서 로그 삽입
app/domains/holiday/schemas.py (수정)
...
class PublicHolidaysApiResponse(BaseModel):
date: Date
local_name: str = Field(..., alias="localName")
name: str
country_code: str = Field(..., alias="countryCode")
is_global: bool = Field(..., alias="global")
counties: Optional[List[str]] = None
launch_year: Optional[int] = Field(None, alias="launchYear")
types: Optional[List[str]] = None
class AvailableCountriesApiResponse(BaseModel):
country_code: str = Field(..., alias="countryCode")
name: str
class SyncDataRequest(BaseModel):
country_code: str = Field(..., description="동기화할 국가 코드", examples=["KR"])
year: int = Field(..., description="동기화할 연도", examples=[2024])
- Pydantic 모델 필드 명명
기본적으로 Pydantic 모델의 필드명은 JSON의 키와 일치해야 한다
부득이한 경우, **
Field(alias="…")**으로 맞춰 준다딕셔너리(JSON 데이터)를 언패킹을 통해 키워드 인자로 전환하여 Pydantic 모델로 값을 전달한다.
[AvailableCountriesApiResponse(**item) for item in response.json()]response.json()- 외부 API의 JSON 응답을 파이썬의 List로 변환
- 이 리스트의 각 요소는 딕셔너리이다
for item in ...- List 순회
AvailableCountriesApiResponse(**item)- 각 딕셔너리의 키-값 쌍을
AvailableCountriesApiResponse클래스의 생성자로 전달 (언패킹) - 이 때 딕셔너리의 키(
countryCode,name)가 Pydantic 필드 이름과 일치해야 값이 정상적으로 할당된다.
- 각 딕셔너리의 키-값 쌍을
- Pydantic은 이 인자들을 받아서 타입 유효성 검사를 수행
- 유효한 값들을 기반으로 새로운 Pydantic 모델 인스턴스를 생성
app/domains/holiday/services.py (수정)
# ...
from app.domains.country.crud import country_crud
from app.domains.holiday.schemas import PublicHolidaysApiResponse
from app.external.api_client import holiday_api_client
from app.core.exceptions import SyncException
class HolidayService(BaseService):
def __init__(self, crud: CRUDHoliday, country_crud: CRUDCountry):
super().__init__(crud)
self.crud = crud
self.country_crud = country_crud
....
async def sync_holidays_by_country(
self, db: AsyncSession, *, country_code: str, years: List[int]
):
try:
async with db.begin():
total_sync_count = 0
for year in years:
total_sync_count += await self._sync_holidays(
db, country_code=country_code, year=year
)
return total_sync_count
except Exception as e:
logger.warning(f"🔴 {country_code}의 공휴일 동기화 중 오류 발생: {str(e)}")
raise HolidaySyncException(
country_code=country_code, year=year, error_message="진행 중 오류 발생"
)
async def sync_holidays_from_api(
self, db: AsyncSession, *, country_code: str, year: int
):
try:
async with db.begin():
return await self._sync_holidays(
db, country_code=country_code, year=year
)
except Exception as e:
logger.warning(f"🔴 {country_code}의 {year}년 공휴일 동기화 중 오류 발생: {str(e)}")
raise HolidaySyncException(
country_code=country_code, year=year, error_message="진행 중 오류 발생"
)
async def _sync_holidays(
self, db: AsyncSession, *, country_code: str, year: int
) -> int:
country = await self.country_crud.get_by_country_code(db, country_code)
if not country:
raise CountryNotFoundException(field_name="국가 코드", value=country_code)
api_data: List[
PublicHolidaysApiResponse
] = await holiday_api_client.fetch_holidays_data(year, country_code)
if not api_data:
raise HolidaySyncException(
country_code=country_code, year=year, error_message="공휴일이 없습니다"
)
api_data_set = set()
unique_data = []
for item in api_data:
unique_key = (item.date, item.name, country.id)
if unique_key not in api_data_set:
api_data_set.add(unique_key)
unique_data.append(item)
holidays = [
Holiday(
date=item.date,
local_name=item.local_name,
name=item.name,
country_id=country.id,
is_global=item.is_global,
counties=",".join(item.counties) if item.counties is not None else None,
launch_year=item.launch_year,
types=",".join(item.types) if item.types is not None else None,
holiday_year=item.date.year,
)
for item in unique_data
]
await self.crud.remove_all_by(db, country_id=country.id, year=year)
await self.crud.create_all(db, holidays)
return len(holidays)
공휴일 외부 API 호출 로직 구현
- api_client 파일의
holiday_api_client인스턴스를 이용하여 공휴일 데이터 가져오기- JSON 파일을 딕셔너리 리스트로 받아오기
- 딕셔너리 리스트를 DTO 리스트로 전환
- DTO 리스트를 Holiday 모델 리스트로 전환
- 해당 범위의 DB 데이터 싹 지우기
- 새로운 Holiday 리스트 데이터 DB에 저장
- api_client 파일의
국가코드로 국가 조회하기
async def _sync_holidays( self, db: AsyncSession, *, country_code: str, year: int ) -> int: country = await self.country_crud.get_by_country_code(db, country_code) if not country: raise CountryNotFoundException(field_name="국가 코드", value=country_code)- 이용 가능한 국가를 미리 조회 저장했으므로, 해당 국가가 없다면 예외 발생
API 클라이언트를 사용하여 외부 데이터 가져오기
api_data: List[ PublicHolidaysApiResponse ] = await holiday_api_client.fetch_holidays_data(year, country_code) if not api_data: raise HolidaySyncException( country_code=country_code, year=year, error_message="공휴일이 없습니다" )- holiday_api_client의
fetch_holidays_data()사용
- holiday_api_client의
DTO를 Holiday 모델로 전환
api_data_set = set() unique_data = [] for item in api_data: unique_key = (item.date, item.name, country.id) if unique_key not in api_data_set: api_data_set.add(unique_key) unique_data.append(item) holidays = [ Holiday( date=item.date, local_name=item.local_name, name=item.name, country_id=country.id, is_global=item.is_global, counties=",".join(item.counties) if item.counties is not None else None, launch_year=item.launch_year, types=",".join(item.types) if item.types is not None else None, holiday_year=item.date.year, ) for item in unique_data ]unique_data리스트는 외부 API 자체에 중복되는 데이터가 있는 경우 처리하기 위함이다.counties,types의 경우 DB에 CSV 형태의 문자열로 저장되도록 처음에 설정했으므로 그에 맞게 변형이 필요하다.holiday_year가 빠져None으로 들어가면 바로 에러가 발생하니 주의하자
DB에 데이터를 저장
await self.crud.remove_all_by(db, country_id=country.id, year=year) await self.crud.create_all(db, holidays) return len(holidays)- 먼저 해당 연도, 국가의 기존 데이터를 삭제하고 새로운 데이터를 저장
- 트랜잭션 관리는 상위 함수에서 관리
app/domains/holiday/router.py (수정)
# ... (기존 임포트)
from fastapi import APIRouter, Query, BackgroundTasks
router = APIRouter(prefix="/holidays", tags=["holidays"])
# ...
@router.post(
"/sync",
response_model=BaseResponse,
description="특정 연도 및 국가 데이터 동기화 (Refresh)",
)
async def sync_holidays(
background_task: BackgroundTasks,
sync_request: SyncDataRequest,
holiday_service: HolidayService = Depends(get_holiday_service),
db: AsyncSession = Depends(get_db),
):
background_task.add_task(
holiday_service.sync_holidays_from_api,
db=db,
year=sync_request.year,
country_code=sync_request.country_code,
)
return BaseResponse(code="204", message="공휴일 동기화 작업 백그라운드 시작", data=None)
테스트용 API 엔드포인트 추가
- Spring 스케줄러처럼 API 요청에 직접 로직을 실행하는 대신, 백그라운드 태스크로 넘겨서 웹 서버의 응답성을 유지한다
- 클라이언트는 즉시 응답을 받고, 서버는 요청을 처리하는 메인 프로세스를 블로킹하지 않고 작업을 이어나갈 수 있다.
BackgroundTasks
- FastAPI가 제공하는 클래스
- FastAPI 엔드포인트 함수에 의존성으로 주입된다
- HTTP 요청이 완료된 후 백그라운드에서 실행될 작업을 등록하는 기능
add_task( )
- 백그라운드에서 실행할 함수를 등록
- 첫 번째 인자는 호출할 함수
- 이후의 모든 인자는 그 함수의 인수로 전달된다
- 작업 실행 시점
add_task()로 등록된 모든 작업은 라우터 함수가 완료되어 클라이언트로 HTTP 응답이 반환된 직후에 실행된다- “나중에 실행해 줘"라고 등록만 해두는 것
app/domains/country/services.py (수정)
# ...
from app.domains.holiday.schemas import AvailableCountriesApiResponse
from app.external.api_client import holiday_api_client
from app.core.exceptions import SyncException
class CountryService(BaseService):
def __init__(self, crud: CRUDCountry):
super().__init__(crud)
self.crud = crud
....
async def sync_countries_from_api(self, db: AsyncSession) -> Dict[str, int]:
try:
async with db.begin():
api_data: List[
AvailableCountriesApiResponse
] = await holiday_api_client.fetch_countries_data()
if not api_data:
raise SyncException(
resource_type="Country", error_message="이용 가능한 국가가 없습니다"
)
api_country_codes = [item.country_code for item in api_data]
existing_countries = await self.crud.get_existing_all(
db, api_country_codes
)
existing_countries_dict = {
country.country_code: country for country in existing_countries
}
new_countries = []
updated_count = 0
for item in api_data:
if item.country_code in existing_countries_dict:
existing_country = existing_countries_dict[item.country_code]
if existing_country.name != item.name:
logger.warning(
f"🟢 국가 업데이트 : {item.country_code} - {existing_country.name} -> {item.name}"
)
existing_country.name = item.name
updated_count += 1
else:
new_countries.append(
Country(
country_code=item.country_code,
name=item.name,
)
)
if new_countries:
await self.crud.create_all(db, new_countries)
return {
"available_countries": len(api_data),
"new_countries": len(new_countries),
"existing_countries": len(existing_countries),
}
except Exception as e:
logger.warning(f"🔴 국가 동기화 중 오류 발생: {str(e)}")
raise SyncException(resource_type="Country", error_message="진행 중 오류 발생")
async def get_countries_to_sync(
self, db: AsyncSession, threshold_hours: int = 24
) -> List[Country]:
threshold_time = datetime.now() - timedelta(hours=threshold_hours)
return await self.crud.get_all_by_sync_time(db, threshold_time)
async def update_sync_time_bulk(
self, db: AsyncSession, db_countries: List[Country]
):
try:
async with db.begin():
await self.crud.update_sync_time_bulk(db, db_countries=db_countries)
except Exception as e:
raise SyncException(
resource_type="Country", error_message="sync_time 업데이트 오류"
)
국가 정보 외부 API 호출 로직 구현
- api_client 파일의
holiday_api_client인스턴스를 이용하여 국가 데이터 가져오기- JSON 파일을 딕셔너리 리스트로 받아오기
- 딕셔너리 리스트를 DTO 리스트로 전환
- DTO 리스트를 Country 모델 리스트로 전환
- 가져온 목록과 기존 목록을 비교 작업
- 수정이 필요한 경우만 수정
- 신규 데이터만 따로 추출
- 새로운 Country 리스트 데이터 DB에 저장
- api_client 파일의
API 클라이언트를 사용하여 외부 데이터 가져오기
async with db.begin(): api_data: List[ AvailableCountriesApiResponse ] = await holiday_api_client.fetch_countries_data() if not api_data: raise SyncException( resource_type="Country", error_message="이용 가능한 국가가 없습니다" )- holiday_api_client의
fetch_countries_data()사용
- holiday_api_client의
기존 데이터와 신규 데이터 분리
api_country_codes = [item.country_code for item in api_data] existing_countries = await self.crud.get_existing_all( db, api_country_codes ) existing_countries_dict = { country.country_code: country for country in existing_countries } new_countries = [] updated_count = 0 for item in api_data: if item.country_code in existing_countries_dict: existing_country = existing_countries_dict[item.country_code] if existing_country.name != item.name: logger.warning( f"🟢 국가 업데이트 : {item.country_code} - {existing_country.name} -> {item.name}" ) existing_country.name = item.name updated_count += 1 else: new_countries.append( Country( country_code=item.country_code, name=item.name, ) )- 기존 데이터의 경우, 수정이 필요한 경우만 수정
- 신규 데이터만 DTO를 Country 모델로 전환
DB에 데이터를 저장
if new_countries: await self.crud.create_all(db, new_countries) return { "available_countries": len(api_data), "new_countries": len(new_countries), "existing_countries": len(existing_countries), }- 새롭게 추출한 신규 데이터만 DB 삽입
- 트랜잭션 관리는 서비스 Layer에서 진행
app/domains/holiday/router.py (수정)
@router.post(
"/sync",
response_model=BaseResponse,
description="국가 데이터 동기화",
)
async def sync_holidays(
background_task: BackgroundTasks,
country_service: CountryService = Depends(get_country_service),
db: AsyncSession = Depends(get_db),
):
background_task.add_task(
country_service.sync_countries_from_api,
db=db,
)
return BaseResponse(code="204", message="국가 동기화 작업 백그라운드 시작", data=None)
app/core/celery_app.py
from celery import Celery
from celery.schedules import crontab
import app.tasks.holiday_tasks # noqa: F401
from app.core.config import settings
celery_app = Celery("holiday_keeper")
celery_app.conf.broker_url = settings.celery_broker_url
celery_app.conf.result_backend = settings.celery_result_backend
celery_app.conf.timezone = settings.celery_timezone
celery_app.conf.enable_utc = settings.celery_enable_utc
celery_app.autodiscover_tasks(["app.tasks"])
beat_schedule = {
"annual-sync": {
"task": "app.tasks.holiday_tasks.sync_holidays",
"schedule": crontab(hour=2, minute=0, day_of_month=2, month_of_year=1),
}
}
if settings.celery_test_sync_enabled:
beat_schedule["test-sync"] = {
"task": "app.tasks.holiday_tasks.sync_holidays",
"schedule": crontab(minute="5"),
}
celery_app.conf.beat_schedule = beat_schedule
환경 변수로 Celery 설정
celery_app.conf.broker_url = settings.celery_broker_url celery_app.conf.result_backend = settings.celery_result_backend celery_app.conf.timezone = settings.celery_timezone celery_app.conf.enable_utc = settings.celery_enable_utc환경 변수로 Test Task 관리
beat_schedule = { "annual-sync": { "task": "app.tasks.holiday_tasks.sync_holidays", "schedule": crontab(hour=2, minute=0, day_of_month=2, month_of_year=1), } } if settings.celery_test_sync_enabled: beat_schedule["test-sync"] = { "task": "app.tasks.holiday_tasks.sync_holidays", "schedule": crontab(minute="5"), } celery_app.conf.beat_schedule = beat_schedule- 환경 변수 설정에 따라 Task 삽입 여부를 결정한다.
app/tasks/holiday_tasks.py
from __future__ import annotations
import asyncio
import logging
from contextlib import asynccontextmanager
from datetime import datetime
from typing import List
from celery import shared_task
from app.core.database import AsyncSessionLocal, async_engine
from app.core.deps import get_country_service, get_holiday_service
from app.domains.country.models import Country
logger = logging.getLogger(__name__)
@asynccontextmanager
async def get_async_app_context():
try:
yield
finally:
await async_engine.dispose()
@shared_task(bind=True, name="app.tasks.holiday_tasks.init_holidays")
def init_holidays(self):
logger.info("🟢 Celery Task - init_holidays 시작")
async def _run():
async with get_async_app_context():
current_year = datetime.now().year
years = [current_year - i for i in range(5)]
await run_sync_task(years)
logger.info(f"🟢 Celery Task - sync_holidays 완료 (연도: {years})")
try:
asyncio.run(_run())
except Exception as e:
logger.error(f"🔴 Celery Task - init_holidays 실행 중 오류 발생 : {e}")
raise self.retry(exc=e, countdown=60, max_retries=3)
@shared_task(bind=True, name="app.tasks.holiday_tasks.sync_holidays")
def sync_holidays(self):
logger.info("🟢 Celery Task - sync_holidays 시작")
async def _run():
async with get_async_app_context():
current_year = datetime.now().year
years = [current_year, current_year - 1]
await run_sync_task(years)
logger.info(f"🟢 Celery Task - sync_holidays 완료 (연도: {years})")
try:
asyncio.run(_run())
except Exception as e:
logger.error(f"🔴 Celery Task - sync_holidays 실행 중 오류 발생 : {e}")
raise self.retry(exc=e, countdown=60, max_retries=3)
async def run_sync_task(years: List[int]):
try:
async with AsyncSessionLocal() as db:
country_service = get_country_service()
holiday_service = get_holiday_service()
logger.info("🟢 국가 데이터 동기화 시작")
result_dict = await country_service.sync_countries_from_api(db)
logger.info(f"🟢 {result_dict['available_countries']}개 국가 데이터 동기화 완료")
logger.warning(
f"🟢 신규 : {result_dict['new_countries']} / 기존 : {result_dict['existing_countries']}"
)
async with db.begin():
countries_to_sync = await country_service.get_countries_to_sync(
db, threshold_hours=24
)
logger.warning(f"🟢 동기화 대상 국가: {len(countries_to_sync)}개")
error_country_codes: List[str] = []
success_country_list: List[Country] = []
total_sync_count = 0
for country in countries_to_sync:
logger.info(f"🟢 국가 코드 {country.country_code} : 공휴일 데이터 동기화 시작")
try:
total_sync_count += await holiday_service.sync_holidays_by_country(
db, country_code=country.country_code, years=years
)
success_country_list.append(country)
except Exception as e:
error_country_codes.append(country.country_code)
logger.warning(
f"🔴 국가 코드 {country.country_code} : 공휴일 동기화 중 오류 발생: {e}"
)
continue
await country_service.update_sync_time_bulk(db, success_country_list)
logger.warning(f"🟢 {years}년 공휴일 {total_sync_count}개 데이터 동기화 완료")
if error_country_codes:
logger.warning(f"🔴 실패 국가 목록 {error_country_codes}")
else:
logger.info(f"🟢 실패 국가 없음")
except Exception as e:
logger.error(f"🔴 동기화 작업 중 에러 발생 : {e}")
raise
app/core/lifespan.py
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI
from app.core.config import settings
logger = logging.getLogger("uvicorn")
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.warning("🟢 FastAPI 애플리케이션 시작")
try:
# 초기 데이터 로딩
if settings.is_prod:
logger.info("🟢 초기 데이터 적재 시작")
await trigger_init_holidays()
else:
logger.info("🟢 개발환경 : 초기 데이터 로딩 Skip")
except Exception as e:
logger.error(f"🔴 초기화 작업 중 오류 발생: {e}")
raise
yield
logger.warning("🟢 FastAPI 애플리케이션 종료 ")
async def trigger_init_holidays():
try:
from app.core.celery_app import celery_app
task = celery_app.send_task("app.tasks.holiday_tasks.init_holidays")
logger.info(f"🟢 초기 데이터 로딩 작업 시작됨 (Task ID : {task.id})")
except Exception as e:
logger.error(f"🔴 초기 데이터 로딩 실패: {e}")
- Celery 활용 초기 데이터 적재하기
app/main.py
app = FastAPI(
title="Holiday Keeper API",
description="공휴일 정보 관리 API",
version="1.0.0",
docs_url="/docs",
redoc_url="/redoc",
lifespan=lifespan,
)
- uvicorn 실행 시 --lifespan on 설정
*uvicorn* app.main:app --reload --host $SERVER_HOST --port $SERVER_PORT --log-level debug --lifespan on