Skip to content
This repository was archived by the owner on Feb 21, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions knowledge_gpt/core/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
from typing import List, Any, Optional
import re

from PIL import Image
import cv2
import pytesseract
import numpy as np

import docx2txt
from langchain.docstore.document import Document
import fitz
Expand Down Expand Up @@ -94,7 +99,60 @@ def from_bytes(cls, file: BytesIO) -> "TxtFile":
doc = Document(page_content=text.strip())
doc.metadata["source"] = "p-1"
return cls(name=file.name, id=md5(file.read()).hexdigest(), docs=[doc])


class ImgFile(File):
@classmethod
def img2txt(cls, bytesio_image: BytesIO) -> str:
pil_image = Image.open(bytesio_image)
image = np.array(pil_image)
# if there is alpha channel, ignore it
if image.shape[2] == 4:
image = image[:, :, :3]
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]

# Remove horizontal lines
horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (50,1))
detect_horizontal = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, horizontal_kernel, iterations=2)
cnts = cv2.findContours(detect_horizontal, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
cnts = cnts[0] if len(cnts) == 2 else cnts[1]
for c in cnts:
cv2.drawContours(thresh, [c], -1, (0,0,0), 2)

# Remove vertical lines
vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1,15))
detect_vertical = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, vertical_kernel, iterations=2)
cnts = cv2.findContours(detect_vertical, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
cnts = cnts[0] if len(cnts) == 2 else cnts[1]
for c in cnts:
cv2.drawContours(thresh, [c], -1, (0,0,0), 3)

# Dilate to connect text and remove dots
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (10,1))
dilate = cv2.dilate(thresh, kernel, iterations=2)
cnts = cv2.findContours(dilate, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
cnts = cnts[0] if len(cnts) == 2 else cnts[1]
for c in cnts:
area = cv2.contourArea(c)
if area < 500:
cv2.drawContours(dilate, [c], -1, (0,0,0), -1)

# Bitwise-and to reconstruct image
result = cv2.bitwise_and(image, image, mask=dilate)
result[dilate==0] = (255,255,255)

data = pytesseract.image_to_string(result)

return data

@classmethod
def from_bytes(cls, file: BytesIO) -> "ImgFile":
text = cls.img2txt(file)
text = strip_consecutive_newlines(text)
doc = Document(page_content=text.strip())
doc.metadata["source"] = "p-1"
return cls(name=file.name, id=md5(file.read()).hexdigest(), docs=[doc])

def read_file(file: BytesIO) -> File:
"""Reads an uploaded file and returns a File object"""
Expand All @@ -104,5 +162,8 @@ def read_file(file: BytesIO) -> File:
return PdfFile.from_bytes(file)
elif file.name.lower().endswith(".txt"):
return TxtFile.from_bytes(file)
elif file.name.lower().endswith(".png") or file.name.lower().endswith(".jpeg") \
or file.name.lower().endswith(".tiff") or file.name.lower().endswith(".jpg"):
return ImgFile.from_bytes(file)
else:
raise NotImplementedError(f"File type {file.name.split('.')[-1]} not supported")
4 changes: 2 additions & 2 deletions knowledge_gpt/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@


uploaded_file = st.file_uploader(
"Upload a pdf, docx, or txt file",
type=["pdf", "docx", "txt"],
"Upload a pdf, docx, txt, png, tiff, jpg, or jpeg file",
type=["pdf", "docx", "txt", "png", "tiff", "jpg", "jpeg"],
help="Scanned documents are not supported yet!",
)

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ pydeck==0.8.0 ; python_version >= "3.10" and python_version < "4.0"
pygments==2.15.1 ; python_version >= "3.10" and python_version < "4.0"
pympler==1.0.1 ; python_version >= "3.10" and python_version < "4.0"
pymupdf==1.22.5 ; python_version >= "3.10" and python_version < "4.0"
pytesseract==0.3.10 ; python_version >= "3.10" and python_version < "4.0"
python-dateutil==2.8.2 ; python_version >= "3.10" and python_version < "4.0"
python-dotenv==0.21.1 ; python_version >= "3.10" and python_version < "4.0"
pytz-deprecation-shim==0.1.0.post0 ; python_version >= "3.10" and python_version < "4.0"
Expand Down
Binary file added resources/samples/test_hello.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 10 additions & 0 deletions tests/unit_tests/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
DocxFile,
PdfFile,
TxtFile,
ImgFile,
read_file,
strip_consecutive_newlines,
)
Expand Down Expand Up @@ -83,6 +84,15 @@ def test_txt_file():
assert len(txt_file.docs) == 1
assert txt_file.docs[0].page_content == "Hello World"

def test_img_file():
with open(SAMPLE_ROOT / "test_hello.jpeg", "rb") as f:
file = BytesIO(f.read())
file.name = "test_hello.jpeg"
img_file = ImgFile.from_bytes(file)

assert img_file.name == "test_hello.jpeg"
assert len(img_file.docs) == 1
assert img_file.docs[0].page_content.lower() == "Hello World".lower()

def test_read_file():
# Test the `read_file` function with each file type
Expand Down