From a6d7a76c37a3499bd0b7e6cd685122baa99e8202 Mon Sep 17 00:00:00 2001
From: ehddnr301
Date: Sun, 27 Apr 2025 06:25:26 +0000
Subject: [PATCH 1/2] feat: split retrieval into its own module and add a reranker
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

#46, #52
---
 interface/lang2sql.py  |  29 ++++++++++++
 llm_utils/chains.py    |   4 ++
 llm_utils/graph.py     |  54 ++++++---------------
 llm_utils/retrieval.py | 105 +++++++++++++++++++++++++++++++++++++++++
 requirements.txt       |   1 +
 setup.py               |   1 +
 6 files changed, 154 insertions(+), 40 deletions(-)
 create mode 100644 llm_utils/retrieval.py

diff --git a/interface/lang2sql.py b/interface/lang2sql.py
index b7a5905..3097906 100644
--- a/interface/lang2sql.py
+++ b/interface/lang2sql.py
@@ -46,6 +46,8 @@ def execute_query(
     *,
     query: str,
     database_env: str,
+    retriever_name: str = "기본",
+    top_n: int = 5,
 ) -> dict:
     """
     Lang2SQL 그래프를 실행하여 자연어 쿼리를 SQL 쿼리로 변환하고 결과를 반환합니다.
@@ -53,6 +55,8 @@
     Args:
         query (str): 자연어로 작성된 사용자 쿼리.
         database_env (str): 사용할 데이터베이스 환경 설정 이름.
+        retriever_name (str): 사용할 검색기 이름.
+        top_n (int): 검색할 테이블 정보의 개수.
 
     Returns:
         dict: 변환된 SQL 쿼리 및 관련 메타데이터를 포함하는 결과 딕셔너리.
@@ -64,6 +68,8 @@
             "messages": [HumanMessage(content=query)],
             "user_database_env": database_env,
             "best_practice_query": "",
+            "retriever_name": retriever_name,
+            "top_n": top_n,
         }
     )
 
@@ -123,6 +129,27 @@ def display_result(
     index=0,
 )
 
+retriever_options = {
+    "기본": "벡터 검색 (기본)",
+    "Reranker": "Reranker 검색 (정확도 향상)",
+}
+
+user_retriever = st.selectbox(
+    "검색기 유형을 선택하세요:",
+    options=list(retriever_options.keys()),
+    format_func=lambda x: retriever_options[x],
+    index=0,
+)
+
+user_top_n = st.slider(
+    "검색할 테이블 정보 개수:",
+    min_value=1,
+    max_value=20,
+    value=5,
+    step=1,
+    help="검색할 테이블 정보의 개수를 설정합니다. 값이 클수록 더 많은 테이블 정보를 검색하지만 처리 시간이 길어질 수 있습니다.",
+)
+
 st.sidebar.title("Output Settings")
 for key, label in SIDEBAR_OPTIONS.items():
     st.sidebar.checkbox(label, value=True, key=key)
@@ -131,5 +158,7 @@ def display_result(
 result = execute_query(
     query=user_query,
     database_env=user_database_env,
+    retriever_name=user_retriever,
+    top_n=user_top_n,
 )
 display_result(res=result, database=db)
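The two new knobs (retriever_name, top_n) flow straight into the graph state, so the same entry point also works outside Streamlit. A minimal sketch, assuming the package is importable from the repo root and the usual OpenAI/database environment variables are set; the environment name is hypothetical, and reading "generated_query" assumes the returned dict mirrors QueryMakerState:

    from interface.lang2sql import execute_query

    res = execute_query(
        query="고객별 최근 주문 내역을 보여줘",
        database_env="dev",          # hypothetical environment name
        retriever_name="Reranker",   # or "기본" for plain vector search
        top_n=5,
    )
    print(res["generated_query"])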
diff --git a/llm_utils/chains.py b/llm_utils/chains.py
index 81d957e..a0a5f27 100644
--- a/llm_utils/chains.py
+++ b/llm_utils/chains.py
@@ -26,6 +26,10 @@ def create_query_refiner_chain(llm):
         [
             SystemMessagePromptTemplate.from_template(prompt),
             MessagesPlaceholder(variable_name="user_input"),
+            SystemMessagePromptTemplate.from_template(
+                "다음은 사용자의 실제 사용 가능한 테이블 및 컬럼 정보입니다:"
+            ),
+            MessagesPlaceholder(variable_name="searched_tables"),
             SystemMessagePromptTemplate.from_template(
                 """
                 위 사용자의 입력을 바탕으로
diff --git a/llm_utils/graph.py b/llm_utils/graph.py
index a6f5137..c0ed5c8 100644
--- a/llm_utils/graph.py
+++ b/llm_utils/graph.py
@@ -14,6 +14,7 @@
 )
 
 from llm_utils.tools import get_info_from_db
+from llm_utils.retrieval import search_tables
 
 # 노드 식별자 정의
 QUERY_REFINER = "query_refiner"
@@ -31,6 +32,8 @@ class QueryMakerState(TypedDict):
     best_practice_query: str
     refined_input: str
     generated_query: str
+    retriever_name: str
+    top_n: int
 
 
 # 노드 함수: QUERY_REFINER 노드
@@ -40,6 +43,7 @@ def query_refiner_node(state: QueryMakerState):
             "user_input": [state["messages"][0].content],
             "user_database_env": [state["user_database_env"]],
             "best_practice_query": [state["best_practice_query"]],
+            "searched_tables": [json.dumps(state["searched_tables"])],
         }
     )
     state["messages"].append(res)
@@ -48,42 +52,12 @@ def query_refiner_node(state: QueryMakerState):
 
 
 def get_table_info_node(state: QueryMakerState):
-    from langchain_community.vectorstores import FAISS
-    from langchain_openai import OpenAIEmbeddings
-
-    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
-    try:
-        db = FAISS.load_local(
-            os.getcwd() + "/table_info_db",
-            embeddings,
-            allow_dangerous_deserialization=True,
-        )
-    except:
-        documents = get_info_from_db()
-        db = FAISS.from_documents(documents, embeddings)
-        db.save_local(os.getcwd() + "/table_info_db")
-    doc_res = db.similarity_search(state["messages"][-1].content)
-    documents_dict = {}
-
-    for doc in doc_res:
-        lines = doc.page_content.split("\n")
-
-        # 테이블명 및 설명 추출
-        table_name, table_desc = lines[0].split(": ", 1)
-
-        # 컬럼 정보 추출
-        columns = {}
-        if len(lines) > 2 and lines[1].strip() == "Columns:":
-            for line in lines[2:]:
-                if ": " in line:
-                    col_name, col_desc = line.split(": ", 1)
-                    columns[col_name.strip()] = col_desc.strip()
-
-        # 딕셔너리 저장
-        documents_dict[table_name] = {
-            "table_description": table_desc.strip(),
-            **columns,  # 컬럼 정보 추가
-        }
+    # retriever_name과 top_n을 이용하여 검색 수행
+    documents_dict = search_tables(
+        query=state["messages"][0].content,
+        retriever_name=state["retriever_name"],
+        top_n=state["top_n"],
+    )
     state["searched_tables"] = documents_dict
 
     return state
@@ -129,19 +103,19 @@ def query_maker_node_with_db_guide(state: QueryMakerState):
 
 # StateGraph 생성 및 구성
 builder = StateGraph(QueryMakerState)
-builder.set_entry_point(QUERY_REFINER)
+builder.set_entry_point(GET_TABLE_INFO)
 
 # 노드 추가
-builder.add_node(QUERY_REFINER, query_refiner_node)
 builder.add_node(GET_TABLE_INFO, get_table_info_node)
+builder.add_node(QUERY_REFINER, query_refiner_node)
 builder.add_node(QUERY_MAKER, query_maker_node)  # query_maker_node_with_db_guide
 # builder.add_node(
 #     QUERY_MAKER, query_maker_node_with_db_guide
 # )  # query_maker_node_with_db_guide
 
 # 기본 엣지 설정
-builder.add_edge(QUERY_REFINER, GET_TABLE_INFO)
-builder.add_edge(GET_TABLE_INFO, QUERY_MAKER)
+builder.add_edge(GET_TABLE_INFO, QUERY_REFINER)
+builder.add_edge(QUERY_REFINER, QUERY_MAKER)
 
 # QUERY_MAKER 노드 후 종료
 builder.add_edge(QUERY_MAKER, END)
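The reordering above (entry point moved to GET_TABLE_INFO) is what makes the refiner prompt work: query_refiner_node now reads state["searched_tables"], so table search has to run first to populate it. For reference, a sketch of the input dict the refiner chain receives after this change; the shape is taken from query_refiner_node above, while the table payload and environment name are made-up examples:

    import json

    refiner_inputs = {
        "user_input": ["지난 7일 매출 합계를 구해줘"],
        "user_database_env": ["dev"],
        "best_practice_query": [""],
        # Populated by GET_TABLE_INFO before the refiner ever runs:
        "searched_tables": [
            json.dumps({"orders": {"table_description": "주문 내역 테이블"}})
        ],
    }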
diff --git a/llm_utils/retrieval.py b/llm_utils/retrieval.py
new file mode 100644
index 0000000..5e517f5
--- /dev/null
+++ b/llm_utils/retrieval.py
@@ -0,0 +1,105 @@
+import os
+from langchain_community.vectorstores import FAISS
+from langchain_openai import OpenAIEmbeddings
+from langchain.retrievers import ContextualCompressionRetriever
+from langchain.retrievers.document_compressors import CrossEncoderReranker
+from langchain_community.cross_encoders import HuggingFaceCrossEncoder
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+from .tools import get_info_from_db
+
+
+def get_vector_db():
+    """벡터 데이터베이스를 로드하거나 생성합니다."""
+    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
+    try:
+        db = FAISS.load_local(
+            os.path.join(os.getcwd(), "table_info_db"),
+            embeddings,
+            allow_dangerous_deserialization=True,
+        )
+    except Exception:
+        print("table_info_db not found; creating a new one")
+        documents = get_info_from_db()
+        db = FAISS.from_documents(documents, embeddings)
+        db.save_local(os.path.join(os.getcwd(), "table_info_db"))
+    return db
+
+
+def load_reranker_model():
+    """한국어 reranker 모델을 로드하거나 다운로드합니다."""
+    local_model_path = os.path.join(os.getcwd(), "ko_reranker_local")
+
+    # 로컬에 저장된 모델이 있으면 불러오고, 없으면 다운로드 후 저장
+    if os.path.exists(local_model_path) and os.path.isdir(local_model_path):
+        print("🔄 ko-reranker 모델 로컬에서 로드 중...")
+    else:
+        print("⬇️ ko-reranker 모델 다운로드 및 저장 중...")
+        model = AutoModelForSequenceClassification.from_pretrained(
+            "Dongjin-kr/ko-reranker",
+        )
+        tokenizer = AutoTokenizer.from_pretrained("Dongjin-kr/ko-reranker")
+        model.save_pretrained(local_model_path)
+        tokenizer.save_pretrained(local_model_path)
+
+    return HuggingFaceCrossEncoder(model_name=local_model_path)
+
+
+def get_retriever(retriever_name: str = "기본", top_n: int = 5):
+    """검색기 타입에 따라 적절한 검색기를 생성합니다.
+
+    Args:
+        retriever_name: 사용할 검색기 이름 ("기본", "Reranker" 등)
+        top_n: 반환할 상위 결과 개수
+    """
+    retrievers = {
+        "기본": lambda: get_vector_db().as_retriever(search_kwargs={"k": top_n}),
+        "Reranker": lambda: ContextualCompressionRetriever(
+            base_compressor=CrossEncoderReranker(
+                model=load_reranker_model(), top_n=top_n
+            ),
+            base_retriever=get_vector_db().as_retriever(search_kwargs={"k": top_n}),
+        ),
+    }
+
+    if retriever_name not in retrievers:
+        print(
+            f"경고: '{retriever_name}' 검색기를 찾을 수 없습니다. 기본 검색기를 사용합니다."
+        )
+        retriever_name = "기본"
+
+    return retrievers[retriever_name]()
+
+
+def search_tables(query: str, retriever_name: str = "기본", top_n: int = 5):
+    """쿼리에 맞는 테이블 정보를 검색합니다."""
+    if retriever_name == "기본":
+        db = get_vector_db()
+        doc_res = db.similarity_search(query, k=top_n)
+    else:
+        retriever = get_retriever(retriever_name=retriever_name, top_n=top_n)
+        doc_res = retriever.invoke(query)
+
+    # 결과를 사전 형태로 변환
+    documents_dict = {}
+    for doc in doc_res:
+        lines = doc.page_content.split("\n")
+
+        # 테이블명 및 설명 추출
+        table_name, table_desc = lines[0].split(": ", 1)
+
+        # 컬럼 정보 추출
+        columns = {}
+        if len(lines) > 2 and lines[1].strip() == "Columns:":
+            for line in lines[2:]:
+                if ": " in line:
+                    col_name, col_desc = line.split(": ", 1)
+                    columns[col_name.strip()] = col_desc.strip()
+
+        # 딕셔너리 저장
+        documents_dict[table_name] = {
+            "table_description": table_desc.strip(),
+            **columns,  # 컬럼 정보 추가
+        }
+
+    return documents_dict
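A usage sketch for the new module (run from the repo root so the relative import inside llm_utils resolves; the query string is only illustrative). Note the first call is slow: it builds table_info_db via get_info_from_db and, on the "Reranker" path, downloads ko-reranker into ./ko_reranker_local:

    from llm_utils.retrieval import search_tables

    tables = search_tables(
        query="유저 가입 및 로그인 이벤트 테이블",
        retriever_name="Reranker",  # unknown names fall back to "기본"
        top_n=5,
    )
    for name, info in tables.items():
        print(name, "-", info["table_description"])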
diff --git a/requirements.txt b/requirements.txt
index 2c506a8..86b2136 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,3 +11,4 @@ pre_commit==4.1.0
 setuptools
 wheel
 twine
+transformers==4.51.2
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 71a31ac..78f0612 100644
--- a/setup.py
+++ b/setup.py
@@ -28,6 +28,7 @@
         "langchain-google-genai>=2.1.3,<3.0.0",
         "langchain-ollama>=0.3.2,<0.4.0",
         "langchain-huggingface>=0.1.2,<0.2.0",
+        "transformers==4.51.2",
    ],
    entry_points={
        "console_scripts": [
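The transformers pin above is what load_reranker_model relies on. If the deployment environment cannot reach Hugging Face at runtime, the local cache can be warmed ahead of time with the same calls the module itself makes; the model id and save path below are the ones hard-coded in llm_utils/retrieval.py:

    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model = AutoModelForSequenceClassification.from_pretrained("Dongjin-kr/ko-reranker")
    tokenizer = AutoTokenizer.from_pretrained("Dongjin-kr/ko-reranker")
    model.save_pretrained("./ko_reranker_local")    # load_reranker_model expects this dir under os.getcwd()
    tokenizer.save_pretrained("./ko_reranker_local")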
From 9656e676e9ff102a8ad91a75c73b7c239efaff12 Mon Sep 17 00:00:00 2001
From: ehddnr301
Date: Mon, 28 Apr 2025 14:04:35 +0000
Subject: [PATCH 2/2] feat: make the model execution device selectable
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 interface/lang2sql.py  |  9 +++++++++
 llm_utils/graph.py     |  2 ++
 llm_utils/retrieval.py | 22 +++++++++++++-------
 3 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/interface/lang2sql.py b/interface/lang2sql.py
index 3097906..8704d95 100644
--- a/interface/lang2sql.py
+++ b/interface/lang2sql.py
@@ -48,6 +48,7 @@ def execute_query(
     database_env: str,
     retriever_name: str = "기본",
     top_n: int = 5,
+    device: str = "cpu",
 ) -> dict:
     """
     Lang2SQL 그래프를 실행하여 자연어 쿼리를 SQL 쿼리로 변환하고 결과를 반환합니다.
@@ -70,6 +71,7 @@
             "best_practice_query": "",
             "retriever_name": retriever_name,
             "top_n": top_n,
+            "device": device,
         }
     )
 
@@ -129,6 +131,12 @@ def display_result(
     index=0,
 )
 
+device = st.selectbox(
+    "모델 실행 장치를 선택하세요:",
+    options=["cpu", "cuda"],
+    index=0,
+)
+
 retriever_options = {
     "기본": "벡터 검색 (기본)",
     "Reranker": "Reranker 검색 (정확도 향상)",
@@ -160,5 +168,6 @@
     database_env=user_database_env,
     retriever_name=user_retriever,
     top_n=user_top_n,
+    device=device,
 )
 display_result(res=result, database=db)
diff --git a/llm_utils/graph.py b/llm_utils/graph.py
index c0ed5c8..69a10b9 100644
--- a/llm_utils/graph.py
+++ b/llm_utils/graph.py
@@ -34,6 +34,7 @@ class QueryMakerState(TypedDict):
     generated_query: str
     retriever_name: str
     top_n: int
+    device: str
 
 
 # 노드 함수: QUERY_REFINER 노드
@@ -57,6 +58,7 @@ def get_table_info_node(state: QueryMakerState):
         query=state["messages"][0].content,
         retriever_name=state["retriever_name"],
         top_n=state["top_n"],
+        device=state["device"],
     )
     state["searched_tables"] = documents_dict
 
diff --git a/llm_utils/retrieval.py b/llm_utils/retrieval.py
index 5e517f5..728141f 100644
--- a/llm_utils/retrieval.py
+++ b/llm_utils/retrieval.py
@@ -26,7 +26,7 @@ def get_vector_db():
     return db
 
 
-def load_reranker_model():
+def load_reranker_model(device: str = "cpu"):
     """한국어 reranker 모델을 로드하거나 다운로드합니다."""
     local_model_path = os.path.join(os.getcwd(), "ko_reranker_local")
 
@@ -36,27 +36,31 @@ def load_reranker_model(device: str = "cpu"):
     else:
         print("⬇️ ko-reranker 모델 다운로드 및 저장 중...")
         model = AutoModelForSequenceClassification.from_pretrained(
-            "Dongjin-kr/ko-reranker",
+            "Dongjin-kr/ko-reranker"
         )
         tokenizer = AutoTokenizer.from_pretrained("Dongjin-kr/ko-reranker")
         model.save_pretrained(local_model_path)
         tokenizer.save_pretrained(local_model_path)
 
-    return HuggingFaceCrossEncoder(model_name=local_model_path)
+    return HuggingFaceCrossEncoder(
+        model_name=local_model_path,
+        model_kwargs={"device": device},
+    )
 
 
-def get_retriever(retriever_name: str = "기본", top_n: int = 5):
+def get_retriever(retriever_name: str = "기본", top_n: int = 5, device: str = "cpu"):
     """검색기 타입에 따라 적절한 검색기를 생성합니다.
 
     Args:
         retriever_name: 사용할 검색기 이름 ("기본", "Reranker" 등)
         top_n: 반환할 상위 결과 개수
+        device: 모델을 실행할 장치 ("cpu" 또는 "cuda")
     """
     retrievers = {
         "기본": lambda: get_vector_db().as_retriever(search_kwargs={"k": top_n}),
         "Reranker": lambda: ContextualCompressionRetriever(
             base_compressor=CrossEncoderReranker(
-                model=load_reranker_model(), top_n=top_n
+                model=load_reranker_model(device), top_n=top_n
             ),
             base_retriever=get_vector_db().as_retriever(search_kwargs={"k": top_n}),
         ),
     }
@@ -71,13 +75,17 @@
     return retrievers[retriever_name]()
 
 
-def search_tables(query: str, retriever_name: str = "기본", top_n: int = 5):
+def search_tables(
+    query: str, retriever_name: str = "기본", top_n: int = 5, device: str = "cpu"
+):
     """쿼리에 맞는 테이블 정보를 검색합니다."""
     if retriever_name == "기본":
         db = get_vector_db()
         doc_res = db.similarity_search(query, k=top_n)
     else:
-        retriever = get_retriever(retriever_name=retriever_name, top_n=top_n)
+        retriever = get_retriever(
+            retriever_name=retriever_name, top_n=top_n, device=device
+        )
         doc_res = retriever.invoke(query)
 
     # 결과를 사전 형태로 변환
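With the second patch applied, the device string reaches HuggingFaceCrossEncoder as model_kwargs={"device": device}; only the "Reranker" path uses it, the plain "기본" vector search ignores it. An end-to-end sketch (the torch fallback below is not part of the patch; torch is assumed installed, which the cross-encoder stack needs anyway):

    import torch
    from llm_utils.retrieval import search_tables

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tables = search_tables(
        query="일별 활성 사용자 수 집계 테이블",
        retriever_name="Reranker",
        top_n=5,
        device=device,
    )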