반응형
문제상황
- 허깅페이스에 SQuAD 데이터를 불러와서 한국어로 번역하여 학습 데이터로 사용하기 위해, 구글번역기 api를 이용하여 이를 구현하려고 한다.
- 데이터를 살펴보니
context
컬럼에는 중복된 데이터가 많이 보인다. 중복된 데이터를 번역하기 위해 매번 API를 호출하면 비효율적이다. - 실행 시간뿐만 아니라 유료 api를 사용한다면 비용적으로도 많이 부담이 될 것이다.
코드 설명
의존성 설치
- 번역에 필요한
googletrans
와 진행 상황을 출력하기위한tqdm
패키지를 설치한다.
!pip install googletrans tqdm
캐시 구현
이미 한번 api를 호출한 데이터를 반복하지 않도록, 캐시 파일을 json으로 생성하고 관리하는 TranslationCache
클래스를 구현한다.
enter
와exit
함수를 통해 인스턴스 생성시 json 파일을 로드하고, 해제시 json 파일을 저장한다.get_translation
으로 캐시된 데이터를 가져오고,add_translation
으로 새로운 데이터를 캐시에 저장한다.
구글 번역 api 구현
구글 번역 api를 호출하는 함수와 캐시를 이용하여 효율적으로 번역을 수행하는 함수를 구현한다.
translate_with_retry
은 실제 api 호출부이며 입력받은 최대 재시도 횟수만큼 반복 시도한다.- 예시 코드에서는 간단하게 1초 대기 후 재시도하는 로직을 작성하지만, 표준 오류 처리 전략 중 하나인 지수 백오프 알고리즘을 구현하는 것도 좋은 방법이다.
SQuAD 데이터에 적용하기
SQuAD 데이터를 가져와서 Dataframe으로 변환하고, 각 column에 대해 번역 기능을 수행한다.
- 예시 코드에서는 SQuAD 데이터에서
context
와question
컬럼만 가져오고, validation 데이터 중 일부만 slice하여 사용한다. - 인스턴스 생성시 with로 감싸서 사용하면 자동으로 enter를 호출하고, 스코프에서 벗어날 때 자동으로 exit을 호출하고 메모리를 반환한다.
테스트
캐시 기능을 사용하지 않았을 때에 비해 api 호출 횟수 비교는 다음과 같다.
- context는 50 → 2
- paragraph는 50 → 47
캐시 파일에 정상적으로 원본 데이터와 번역 데이터가 쌍으로 저장되었다.
csv 파일도 번역 결과가 정상적으로 저장되었다.
결론
- Cache를 구현하여 효율적으로 API를 호출하여 시간과 비용을 절약할 수 있다.
- 구글 번역 API 뿐만 아니라 맞춤법 검사기, OpenAPI 등에도 응용하여 사용할 수 있다.
- 의도하지 않은 결과가 있다면 캐시 파일만 수정하여 일괄적으로 변경을 적용할 수 있다.
전체 소스 코드
import json
import os
import time
from datasets import load_dataset
from googletrans import Translator
import pandas as pd
from tqdm import tqdm
class TranslationCache:
def __init__(self, cache_file="translation_cache.json"):
self.cache_file = cache_file
self.cache = {}
def __enter__(self):
if os.path.exists(self.cache_file):
with open(self.cache_file, "r", encoding="utf-8") as f:
self.cache = json.load(f)
return self
def __exit__(self, exc_type, exc_value, traceback):
with open(self.cache_file, "w", encoding="utf-8") as f:
json.dump(self.cache, f, ensure_ascii=False, indent=2)
self.cache.clear() # 메모리 해제
def get_translation(self, text):
return self.cache.get(text)
def add_translation(self, text, translation):
self.cache[text] = translation
def call_googletrans_api_with_retry(translator, text, cache, max_retries=3):
# 캐시에서 번역 확인
cached_translation = cache.get_translation(text)
if cached_translation:
return cached_translation
# 캐시에 없는 경우 번역 수행
for attempt in range(max_retries):
try:
translated = translator.translate(text, src="en", dest="ko").text
time.sleep(0.01)
cache.add_translation(text, translated)
return translated
except Exception as e:
if attempt == max_retries - 1:
print(f"번역 실패: {str(e)}")
return text
time.sleep(1)
def translate(texts, cache):
translator = Translator()
translated_texts = []
# 중복 제거를 위해 유니크한 텍스트만 추출
unique_texts = list(set(texts))
for text in tqdm(unique_texts, desc="텍스트 번역 중"):
translated = call_googletrans_api_with_retry(translator, text, cache)
# 원본 순서대로 캐시에서 번역 가져오기
for text in texts:
translated = cache.get_translation(text)
translated_texts.append(translated)
return translated_texts
def squad_datasets_to_csv(slice_length, output_filename):
data = load_dataset("rajpurkar/squad")
df = pd.DataFrame(data["validation"])[["context", "question"]]
sliced_df = df.iloc[:slice_length]
sliced_df.to_csv(output_filename, index=False)
def translate_csv(input_filename, output_filename):
df = pd.read_csv(input_filename)
print("context 번역 중...")
with TranslationCache("context_cache.json") as context_cache:
df["context"] = translate(df["context"], context_cache)
print("질문 번역 중...")
with TranslationCache("question_cache.json") as question_cache:
df["question"] = translate(df["question"], question_cache)
df.to_csv(output_filename, index=False)
squad_datasets_to_csv(50, "squad_en_50.csv")
translate_csv("squad_en_50.csv", "squad_ko_50.csv")
반응형
댓글