text-splitters: Set strict mypy rules #30900

Merged
merged 1 commit on Apr 23, 2025
6 changes: 3 additions & 3 deletions libs/text-splitters/Makefile
@@ -43,9 +43,9 @@ lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
./scripts/lint_imports.sh
[ "$(PYTHON_FILES)" = "" ] || uv run --group typing --group lint ruff check $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || uv run --group typing --group lint ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && uv run --group typing --group lint mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff check $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && uv run --all-groups mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff format $(PYTHON_FILES)
2 changes: 1 addition & 1 deletion libs/text-splitters/langchain_text_splitters/base.py
@@ -68,7 +68,7 @@ def split_text(self, text: str) -> List[str]:
"""Split text into multiple components."""

def create_documents(
- self, texts: List[str], metadatas: Optional[List[dict]] = None
+ self, texts: list[str], metadatas: Optional[list[dict[Any, Any]]] = None
) -> List[Document]:
"""Create documents from a list of texts."""
_metadatas = metadatas or [{}] * len(texts)
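For context, a minimal usage sketch of the signature being tightened here; the splitter class and values are illustrative, not part of this diff. The annotation change just spells out that `metadatas` is one plain dict per input text.

```python
from langchain_text_splitters import CharacterTextSplitter

splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0)
docs = splitter.create_documents(
    texts=["First document text.", "Second document text."],
    metadatas=[{"source": "a.txt"}, {"source": "b.txt"}],  # one metadata dict per text
)
```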
24 changes: 11 additions & 13 deletions libs/text-splitters/langchain_text_splitters/html.py
@@ -353,8 +353,8 @@ def split_text(self, text: str) -> List[Document]:
return self.split_text_from_file(StringIO(text))

def create_documents(
- self, texts: List[str], metadatas: Optional[List[dict]] = None
- ) -> List[Document]:
+ self, texts: list[str], metadatas: Optional[list[dict[Any, Any]]] = None
+ ) -> list[Document]:
"""Create documents from a list of texts."""
_metadatas = metadatas or [{}] * len(texts)
documents = []
@@ -389,10 +389,8 @@ def split_html_by_headers(self, html_doc: str) -> List[Dict[str, Optional[str]]]
- 'tag_name': The name of the header tag (e.g., "h1", "h2").
"""
try:
- from bs4 import (
-     BeautifulSoup,  # type: ignore[import-untyped]
-     PageElement,
- )
+ from bs4 import BeautifulSoup
+ from bs4.element import PageElement
except ImportError as e:
raise ImportError(
"Unable to import BeautifulSoup/PageElement, \
@@ -411,13 +409,13 @@
if i == 0:
current_header = "#TITLE#"
current_header_tag = "h1"
- section_content: List = []
+ section_content: list[str] = []
else:
current_header = header_element.text.strip()
current_header_tag = header_element.name # type: ignore[attr-defined]
section_content = []
for element in header_element.next_elements:
- if i + 1 < len(headers) and element == headers[i + 1]:
+ if i + 1 < len(headers) and element == headers[i + 1]:  # type: ignore[comparison-overlap]
break
if isinstance(element, str):
section_content.append(element)
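A small illustration (not from this PR) of what `comparison-overlap` means: strict mode enables `strict_equality`, so mypy rejects `==` between values whose inferred types it believes can never be equal, which is why the narrow ignore is added above.

```python
def is_marker(element: str, marker: int) -> bool:
    # mypy --strict reports:
    #   Non-overlapping equality check (left operand type: "str",
    #   right operand type: "int")  [comparison-overlap]
    return element == marker  # type: ignore[comparison-overlap]
```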
@@ -637,8 +635,8 @@ def __init__(

if self._stopword_removal:
try:
- import nltk  # type: ignore
- from nltk.corpus import stopwords  # type: ignore
+ import nltk
+ from nltk.corpus import stopwords  # type: ignore[import-untyped]

nltk.download("stopwords")
self._stopwords = set(stopwords.words(self._stopword_lang))
@@ -902,7 +900,7 @@ def _process_element(
return documents

def _create_documents(
- self, headers: dict, content: str, preserved_elements: dict
+ self, headers: dict[str, str], content: str, preserved_elements: dict[str, str]
) -> List[Document]:
"""Creates Document objects from the provided headers, content, and elements.

@@ -928,7 +926,7 @@ def _create_documents(
return self._further_split_chunk(content, metadata, preserved_elements)

def _further_split_chunk(
- self, content: str, metadata: dict, preserved_elements: dict
+ self, content: str, metadata: dict[Any, Any], preserved_elements: dict[str, str]
) -> List[Document]:
"""Further splits the content into smaller chunks.

@@ -959,7 +957,7 @@ def _further_split_chunk(
return result

def _reinsert_preserved_elements(
- self, content: str, preserved_elements: dict
+ self, content: str, preserved_elements: dict[str, str]
) -> str:
"""Reinserts preserved elements into the content into their original positions.

22 changes: 11 additions & 11 deletions libs/text-splitters/langchain_text_splitters/json.py
@@ -49,12 +49,12 @@ def __init__(
)

@staticmethod
- def _json_size(data: Dict) -> int:
+ def _json_size(data: dict[str, Any]) -> int:
"""Calculate the size of the serialized JSON object."""
return len(json.dumps(data))

@staticmethod
- def _set_nested_dict(d: Dict, path: List[str], value: Any) -> None:
+ def _set_nested_dict(d: dict[str, Any], path: list[str], value: Any) -> None:
"""Set a value in a nested dictionary based on the given path."""
for key in path[:-1]:
d = d.setdefault(key, {})
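For readers skimming the diff, a self-contained sketch of what a helper shaped like `_set_nested_dict` does; the standalone name and example values here are illustrative.

```python
from typing import Any


def set_nested(d: dict[str, Any], path: list[str], value: Any) -> None:
    """Walk (and create) intermediate dicts along ``path``, then assign the last key."""
    for key in path[:-1]:
        d = d.setdefault(key, {})
    d[path[-1]] = value


cfg: dict[str, Any] = {}
set_nested(cfg, ["a", "b"], 1)
assert cfg == {"a": {"b": 1}}
```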
@@ -76,10 +76,10 @@ def _list_to_dict_preprocessing(self, data: Any) -> Any:

def _json_split(
self,
- data: Dict[str, Any],
- current_path: Optional[List[str]] = None,
- chunks: Optional[List[Dict]] = None,
- ) -> List[Dict]:
+ data: dict[str, Any],
+ current_path: Optional[list[str]] = None,
+ chunks: Optional[list[dict[str, Any]]] = None,
+ ) -> list[dict[str, Any]]:
"""Split json into maximum size dictionaries while preserving structure."""
current_path = current_path or []
chunks = chunks if chunks is not None else [{}]
@@ -107,9 +107,9 @@ def _json_split(

def split_json(
self,
- json_data: Dict[str, Any],
+ json_data: dict[str, Any],
convert_lists: bool = False,
- ) -> List[Dict]:
+ ) -> list[dict[str, Any]]:
"""Splits JSON into a list of JSON chunks."""
if convert_lists:
chunks = self._json_split(self._list_to_dict_preprocessing(json_data))
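A hedged usage sketch of the public method whose annotations change here; the `max_chunk_size` value and the input data are arbitrary.

```python
from langchain_text_splitters import RecursiveJsonSplitter

splitter = RecursiveJsonSplitter(max_chunk_size=300)
chunks = splitter.split_json(
    json_data={"settings": {"theme": "dark", "lang": "en"}, "flags": {"beta": True}},
    convert_lists=False,
)
# Each element of `chunks` is a dict whose serialized size stays under the limit.
```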
@@ -135,11 +135,11 @@ def split_text(

def create_documents(
self,
- texts: List[Dict],
+ texts: list[dict[str, Any]],
convert_lists: bool = False,
ensure_ascii: bool = True,
- metadatas: Optional[List[dict]] = None,
- ) -> List[Document]:
+ metadatas: Optional[list[dict[Any, Any]]] = None,
+ ) -> list[Document]:
"""Create documents from a list of json objects (Dict)."""
_metadatas = metadatas or [{}] * len(texts)
documents = []
6 changes: 3 additions & 3 deletions libs/text-splitters/langchain_text_splitters/markdown.py
@@ -404,18 +404,18 @@ def _complete_chunk_doc(self) -> None:
self.current_chunk = Document(page_content="")

# Match methods
- def _match_header(self, line: str) -> Union[re.Match, None]:
+ def _match_header(self, line: str) -> Union[re.Match[str], None]:
match = re.match(r"^(#{1,6}) (.*)", line)
# Only matches on the configured headers
if match and match.group(1) in self.splittable_headers:
return match
return None

- def _match_code(self, line: str) -> Union[re.Match, None]:
+ def _match_code(self, line: str) -> Union[re.Match[str], None]:
matches = [re.match(rule, line) for rule in [r"^```(.*)", r"^~~~(.*)"]]
return next((match for match in matches if match), None)

- def _match_horz(self, line: str) -> Union[re.Match, None]:
+ def _match_horz(self, line: str) -> Union[re.Match[str], None]:
matches = [
re.match(rule, line) for rule in [r"^\*\*\*+\n", r"^---+\n", r"^___+\n"]
]
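A short sketch (illustrative, not part of the diff) of why the parameter now has to be spelled out: strict mode enables `disallow_any_generics`, which rejects the bare generic.

```python
import re
from typing import Union


# With a bare `re.Match` return annotation, mypy --strict reports:
#   Missing type parameters for generic type "Match"
def match_header(line: str) -> Union[re.Match[str], None]:
    return re.match(r"^(#{1,6}) (.*)", line)
```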
@@ -35,7 +35,7 @@ def __init__(
def _initialize_chunk_configuration(
self, *, tokens_per_chunk: Optional[int]
) -> None:
- self.maximum_tokens_per_chunk = cast(int, self._model.max_seq_length)
+ self.maximum_tokens_per_chunk = self._model.max_seq_length

if tokens_per_chunk is None:
self.tokens_per_chunk = self.maximum_tokens_per_chunk
@@ -93,10 +93,10 @@ def count_tokens(self, *, text: str) -> int:

_max_length_equal_32_bit_integer: int = 2**32

- def _encode(self, text: str) -> List[int]:
+ def _encode(self, text: str) -> list[int]:
token_ids_with_start_and_end_token_ids = self.tokenizer.encode(
text,
max_length=self._max_length_equal_32_bit_integer,
truncation="do_not_truncate",
)
- return token_ids_with_start_and_end_token_ids
+ return cast("list[int]", token_ids_with_start_and_end_token_ids)
Collaborator
Surprised we have to cast this

Collaborator Author
It's because mypy can't figure out the type, so it falls back to Any, which is incompatible with the method signature.
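A minimal illustration of that interaction, with generic names rather than the splitter's actual code:

```python
from typing import Any, cast


def parse_ids(payload: Any) -> list[int]:
    ids = payload["ids"]  # mypy infers Any here
    # Without the cast, strict mode's warn_return_any reports:
    #   Returning Any from function declared to return "list[int]"
    return cast("list[int]", ids)
```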

10 changes: 7 additions & 3 deletions libs/text-splitters/pyproject.toml
@@ -20,7 +20,7 @@ repository = "https://github.com/langchain-ai/langchain"
[dependency-groups]
lint = ["ruff<1.0.0,>=0.9.2", "langchain-core"]
typing = [
"mypy<2.0,>=1.10",
"mypy<2.0,>=1.15",
"lxml-stubs<1.0.0,>=0.5.1",
"types-requests<3.0.0.0,>=2.31.0.20240218",
"tiktoken<1.0.0,>=0.8.0",
@@ -48,7 +48,11 @@ test_integration = [
langchain-core = { path = "../core", editable = true }

[tool.mypy]
- disallow_untyped_defs = "True"
+ strict = "True"
+ strict_bytes = "True"
+ enable_error_code = "deprecated"
+ report_deprecated_as_note = "True"

[[tool.mypy.overrides]]
module = [
"transformers",
@@ -70,7 +74,7 @@ ignore_missing_imports = "True"
target-version = "py39"

[tool.ruff.lint]
- select = ["E", "F", "I", "T201", "D"]
+ select = ["E", "F", "I", "PGH003", "T201", "D"]
Collaborator
Can we add UP (pyupgrade)? Reminded of this because of the List -> list changes.

Collaborator
Can be done in a different PR

Collaborator Author
Yes, can be done in another PR

ignore = ["D100"]

[tool.coverage.run]
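To make the tightened settings concrete, a few hypothetical lines that the new configuration rejects but the previous `disallow_untyped_defs`-only setup tolerated (messages paraphrased):

```python
from typing import Any


def merge(a: dict, b: dict) -> dict:
    # strict mode (disallow_any_generics): Missing type parameters for generic type "dict"
    return {**a, **b}


def first_id(payload: Any) -> int:
    # strict mode (warn_return_any): Returning Any from function declared to return "int"
    return payload["id"]


count = int("1")  # type: ignore  # ruff PGH003: blanket ignore; a specific error code is required
```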
@@ -20,7 +20,7 @@ def spacy() -> Any:
import spacy
except ImportError:
pytest.skip("Spacy not installed.")
- spacy.cli.download("en_core_web_sm")  # type: ignore
+ spacy.cli.download("en_core_web_sm")  # type: ignore[attr-defined,operator,unused-ignore]
return spacy
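One inference about the triple ignore, not stated in the PR: which of `attr-defined`/`operator` fires can depend on the installed spacy version and its stubs, and strict mode also turns on `warn_unused_ignores`, so `unused-ignore` keeps the comment valid in environments where the other codes don't trigger. Roughly:

```python
# Hypothetical sketch of the same pattern, annotated:
import spacy  # optional test dependency; typing support varies by version

spacy.cli.download("en_core_web_sm")  # type: ignore[attr-defined,operator,unused-ignore]
# - attr-defined / operator: silence the error in environments where mypy flags this call
# - unused-ignore: silence warn_unused_ignores in environments where it does not
```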

