Skip to content

Feature/46 retrieve성능 향상 #76

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions interface/lang2sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,18 @@ def execute_query(
*,
query: str,
database_env: str,
retriever_name: str = "기본",
top_n: int = 5,
device: str = "cpu",
) -> dict:
"""
Lang2SQL 그래프를 실행하여 자연어 쿼리를 SQL 쿼리로 변환하고 결과를 반환합니다.

Args:
query (str): 자연어로 작성된 사용자 쿼리.
database_env (str): 사용할 데이터베이스 환경 설정 이름.
retriever_name (str): 사용할 검색기 이름.
top_n (int): 검색할 테이블 정보의 개수.

Returns:
dict: 변환된 SQL 쿼리 및 관련 메타데이터를 포함하는 결과 딕셔너리.
Expand All @@ -64,6 +69,9 @@ def execute_query(
"messages": [HumanMessage(content=query)],
"user_database_env": database_env,
"best_practice_query": "",
"retriever_name": retriever_name,
"top_n": top_n,
"device": device,
}
)

Expand Down Expand Up @@ -123,6 +131,33 @@ def display_result(
index=0,
)

device = st.selectbox(
"모델 실행 장치를 선택하세요:",
options=["cpu", "cuda"],
index=0,
)

retriever_options = {
"기본": "벡터 검색 (기본)",
"Reranker": "Reranker 검색 (정확도 향상)",
}

user_retriever = st.selectbox(
"검색기 유형을 선택하세요:",
options=list(retriever_options.keys()),
format_func=lambda x: retriever_options[x],
index=0,
)

user_top_n = st.slider(
"검색할 테이블 정보 개수:",
min_value=1,
max_value=20,
value=5,
step=1,
help="검색할 테이블 정보의 개수를 설정합니다. 값이 클수록 더 많은 테이블 정보를 검색하지만 처리 시간이 길어질 수 있습니다.",
)

st.sidebar.title("Output Settings")
for key, label in SIDEBAR_OPTIONS.items():
st.sidebar.checkbox(label, value=True, key=key)
Expand All @@ -131,5 +166,8 @@ def display_result(
result = execute_query(
query=user_query,
database_env=user_database_env,
retriever_name=user_retriever,
top_n=user_top_n,
device=device,
)
display_result(res=result, database=db)
4 changes: 4 additions & 0 deletions llm_utils/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ def create_query_refiner_chain(llm):
[
SystemMessagePromptTemplate.from_template(prompt),
MessagesPlaceholder(variable_name="user_input"),
SystemMessagePromptTemplate.from_template(
"다음은 사용자의 실제 사용 가능한 테이블 및 컬럼 정보입니다:"
),
MessagesPlaceholder(variable_name="searched_tables"),
SystemMessagePromptTemplate.from_template(
"""
위 사용자의 입력을 바탕으로
Expand Down
56 changes: 16 additions & 40 deletions llm_utils/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)

from llm_utils.tools import get_info_from_db
from llm_utils.retrieval import search_tables

# 노드 식별자 정의
QUERY_REFINER = "query_refiner"
Expand All @@ -31,6 +32,9 @@ class QueryMakerState(TypedDict):
best_practice_query: str
refined_input: str
generated_query: str
retriever_name: str
top_n: int
device: str


# 노드 함수: QUERY_REFINER 노드
Expand All @@ -40,6 +44,7 @@ def query_refiner_node(state: QueryMakerState):
"user_input": [state["messages"][0].content],
"user_database_env": [state["user_database_env"]],
"best_practice_query": [state["best_practice_query"]],
"searched_tables": [json.dumps(state["searched_tables"])],
}
)
state["messages"].append(res)
Expand All @@ -48,42 +53,13 @@ def query_refiner_node(state: QueryMakerState):


def get_table_info_node(state: QueryMakerState):
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
try:
db = FAISS.load_local(
os.getcwd() + "/table_info_db",
embeddings,
allow_dangerous_deserialization=True,
)
except:
documents = get_info_from_db()
db = FAISS.from_documents(documents, embeddings)
db.save_local(os.getcwd() + "/table_info_db")
doc_res = db.similarity_search(state["messages"][-1].content)
documents_dict = {}

for doc in doc_res:
lines = doc.page_content.split("\n")

# 테이블명 및 설명 추출
table_name, table_desc = lines[0].split(": ", 1)

# 컬럼 정보 추출
columns = {}
if len(lines) > 2 and lines[1].strip() == "Columns:":
for line in lines[2:]:
if ": " in line:
col_name, col_desc = line.split(": ", 1)
columns[col_name.strip()] = col_desc.strip()

# 딕셔너리 저장
documents_dict[table_name] = {
"table_description": table_desc.strip(),
**columns, # 컬럼 정보 추가
}
# retriever_name과 top_n을 이용하여 검색 수행
documents_dict = search_tables(
query=state["messages"][0].content,
retriever_name=state["retriever_name"],
top_n=state["top_n"],
device=state["device"],
)
state["searched_tables"] = documents_dict

return state
Expand Down Expand Up @@ -129,19 +105,19 @@ def query_maker_node_with_db_guide(state: QueryMakerState):

# StateGraph 생성 및 구성
builder = StateGraph(QueryMakerState)
builder.set_entry_point(QUERY_REFINER)
builder.set_entry_point(GET_TABLE_INFO)

# 노드 추가
builder.add_node(QUERY_REFINER, query_refiner_node)
builder.add_node(GET_TABLE_INFO, get_table_info_node)
builder.add_node(QUERY_REFINER, query_refiner_node)
builder.add_node(QUERY_MAKER, query_maker_node) # query_maker_node_with_db_guide
# builder.add_node(
# QUERY_MAKER, query_maker_node_with_db_guide
# ) # query_maker_node_with_db_guide

# 기본 엣지 설정
builder.add_edge(QUERY_REFINER, GET_TABLE_INFO)
builder.add_edge(GET_TABLE_INFO, QUERY_MAKER)
builder.add_edge(GET_TABLE_INFO, QUERY_REFINER)
builder.add_edge(QUERY_REFINER, QUERY_MAKER)

# QUERY_MAKER 노드 후 종료
builder.add_edge(QUERY_MAKER, END)
113 changes: 113 additions & 0 deletions llm_utils/retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import os
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from .tools import get_info_from_db


def get_vector_db():
    """Load the table-info FAISS index, building and saving it if loading fails.

    The index is read from ``table_info_db`` under the current working
    directory. When it cannot be loaded, the documents returned by
    ``get_info_from_db()`` are embedded with the ``text-embedding-3-small``
    OpenAI model, indexed, and written back to the same directory.

    Returns:
        The loaded or newly built FAISS vector store.
    """
    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
    db_path = os.path.join(os.getcwd(), "table_info_db")
    try:
        # NOTE(review): allow_dangerous_deserialization unpickles the stored
        # docstore; only safe while table_info_db is produced by this
        # function and not supplied by an untrusted party.
        db = FAISS.load_local(
            db_path,
            embeddings,
            allow_dangerous_deserialization=True,
        )
    except Exception:
        # Previously a bare ``except:``, which also trapped KeyboardInterrupt
        # and SystemExit and then started a full re-embedding run.
        documents = get_info_from_db()
        db = FAISS.from_documents(documents, embeddings)
        db.save_local(db_path)
        print("table_info_db not found")
    return db


def load_reranker_model(device: str = "cpu"):
    """Return a cross-encoder backed by the Korean reranker model.

    The model is cached in ``ko_reranker_local`` under the current working
    directory. If that directory is missing, the ``Dongjin-kr/ko-reranker``
    model and tokenizer are fetched and saved there first.

    Args:
        device: Device name forwarded to the cross-encoder through
            ``model_kwargs`` (the default is "cpu").

    Returns:
        A ``HuggingFaceCrossEncoder`` that loads from the local model directory.
    """
    hub_model_id = "Dongjin-kr/ko-reranker"
    model_dir = os.path.join(os.getcwd(), "ko_reranker_local")

    if not os.path.isdir(model_dir):
        # No local copy yet: fetch both artifacts, then persist them together.
        print("⬇️ ko-reranker 모델 다운로드 및 저장 중...")
        reranker = AutoModelForSequenceClassification.from_pretrained(hub_model_id)
        reranker_tokenizer = AutoTokenizer.from_pretrained(hub_model_id)
        reranker.save_pretrained(model_dir)
        reranker_tokenizer.save_pretrained(model_dir)
    else:
        print("🔄 ko-reranker 모델 로컬에서 로드 중...")

    cross_encoder = HuggingFaceCrossEncoder(
        model_name=model_dir,
        model_kwargs={"device": device},
    )
    return cross_encoder


def get_retriever(retriever_name: str = "기본", top_n: int = 5, device: str = "cpu"):
    """Build the retriever registered under ``retriever_name``.

    Args:
        retriever_name: ``"기본"`` for plain vector search, or ``"Reranker"``
            for vector search followed by cross-encoder reranking. Any other
            name prints a warning and falls back to ``"기본"``.
        top_n: Number of documents fetched from the vector store, and also
            the number kept after reranking.
        device: Device passed to ``load_reranker_model``; only used by the
            ``"Reranker"`` retriever.

    Returns:
        The retriever produced for the selected (or fallback) name.
    """
    # Factories are lazy so the reranker model is only loaded when selected.
    # NOTE(review): the "Reranker" base retriever also fetches only ``top_n``
    # candidates, so reranking can reorder but never widen the result set --
    # confirm whether a larger candidate pool was intended.
    retrievers = {
        "기본": lambda: get_vector_db().as_retriever(search_kwargs={"k": top_n}),
        "Reranker": lambda: ContextualCompressionRetriever(
            base_compressor=CrossEncoderReranker(
                model=load_reranker_model(device), top_n=top_n
            ),
            base_retriever=get_vector_db().as_retriever(search_kwargs={"k": top_n}),
        ),
    }

    if retriever_name not in retrievers:
        print(
            f"경고: '{retriever_name}' 검색기를 찾을 수 없습니다. 기본 검색기를 사용합니다."
        )
        retriever_name = "기본"

    return retrievers[retriever_name]()


def _parse_table_document(page_content: str):
    """Split one table document into ``(table_name, info_dict)``.

    The first line is read as ``"<table name>: <description>"``. Column
    entries are only read when the second line is ``"Columns:"``, and each
    following ``"<column>: <description>"`` line becomes one entry.
    """
    lines = page_content.split("\n")

    # partition() never raises, so a header without ": " yields an empty
    # description instead of the ValueError the old tuple-unpack produced.
    table_name, _, table_desc = lines[0].partition(": ")

    columns = {}
    if len(lines) > 2 and lines[1].strip() == "Columns:":
        for line in lines[2:]:
            col_name, separator, col_desc = line.partition(": ")
            if separator:
                columns[col_name.strip()] = col_desc.strip()

    return table_name, {
        "table_description": table_desc.strip(),
        **columns,
    }


def search_tables(
    query: str, retriever_name: str = "기본", top_n: int = 5, device: str = "cpu"
):
    """Search for table information matching ``query``.

    Args:
        query: Text to search the table documents with.
        retriever_name: ``"기본"`` runs a similarity search directly on the
            vector store; any other value is resolved by ``get_retriever``.
        top_n: Number of documents to retrieve.
        device: Device forwarded to ``get_retriever``.

    Returns:
        A dict keyed by table name. Each value is a dict holding
        ``"table_description"`` plus one entry per parsed column. If two
        documents share a table name, the later one replaces the earlier.
    """
    if retriever_name == "기본":
        db = get_vector_db()
        doc_res = db.similarity_search(query, k=top_n)
    else:
        retriever = get_retriever(
            retriever_name=retriever_name, top_n=top_n, device=device
        )
        doc_res = retriever.invoke(query)

    documents_dict = {}
    for doc in doc_res:
        content = doc.page_content
        # A blank header line carries no table name; such documents used to
        # abort the whole search with ValueError, so they are skipped here.
        if not content.split("\n", 1)[0].strip():
            continue
        table_name, table_info = _parse_table_document(content)
        documents_dict[table_name] = table_info

    return documents_dict
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pre_commit==4.1.0
setuptools
wheel
twine
transformers==4.51.2
langchain-aws>=0.2.21,<0.3.0
langchain-google-genai>=2.1.3,<3.0.0
langchain-ollama>=0.3.2,<0.4.0
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"langchain-google-genai>=2.1.3,<3.0.0",
"langchain-ollama>=0.3.2,<0.4.0",
"langchain-huggingface>=0.1.2,<0.2.0",
"transformers==4.51.2",
],
entry_points={
"console_scripts": [
Expand Down