switching to orjson, some better JSON error handling
jakep-allenai committed Oct 25, 2024
1 parent 48a3aff commit 8e6d0c6
Showing 1 changed file with 36 additions and 26 deletions.
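
A quick note on the json → orjson switch in this diff: orjson.dumps returns compact UTF-8 bytes rather than a str, and orjson.loads accepts bytes (or str) directly, which is why the explicit .encode("utf-8") / .decode('utf-8') calls disappear below. A minimal illustrative sketch of that difference (not part of the commit itself):

import json
import orjson

obj = {"is_rotation_valid": True, "natural_text": "héllo"}

# json.dumps returns a str that must be encoded by hand; orjson.dumps returns
# compact UTF-8 bytes directly and, like ensure_ascii=False, does not ASCII-escape.
json_bytes = json.dumps(obj, ensure_ascii=False).encode("utf-8")
orjson_bytes = orjson.dumps(obj)

# Both round-trip to the same object, and orjson.loads takes the bytes
# without an explicit .decode("utf-8") step.
assert orjson.loads(orjson_bytes) == json.loads(json_bytes.decode("utf-8")) == obj
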
62 changes: 36 additions & 26 deletions pdelfin/birrpipeline.py
@@ -2,7 +2,7 @@
import hashlib
import boto3
import sqlite3
-import json
+import orjson
import argparse
import uuid
import tempfile
@@ -23,7 +23,7 @@
from concurrent.futures import ProcessPoolExecutor, as_completed

from pdelfin.data.renderpdf import render_pdf_to_base64png
-from pdelfin.prompts import build_finetuning_prompt
+from pdelfin.prompts import build_finetuning_prompt, PageResponse
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.s3_utils import parse_custom_id, expand_s3_glob, get_s3_bytes, parse_s3_path

@@ -36,6 +36,7 @@
logging.getLogger("pypdf").setLevel(logging.ERROR)
logging.getLogger("smart_open").setLevel(logging.ERROR)


class DatabaseManager:
@dataclass(frozen=True)
class BatchInferenceRecord:
@@ -57,15 +58,17 @@ class PDFRecord:
num_pages: int
status: str

-def __init__(self, s3_workspace: str):
+def __init__(self, s3_workspace: str, skip_init: bool=False):
cache_key = hashlib.sha256(s3_workspace.strip().lower().encode('utf-8')).hexdigest()
home_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', cache_key)
os.makedirs(home_cache_dir, exist_ok=True)
self.db_path = os.path.join(home_cache_dir, 'index.db')

self.conn = sqlite3.connect(self.db_path)
self.cursor = self.conn.cursor()
-self._initialize_tables()

+if not skip_init:
+self._initialize_tables()

def _initialize_tables(self):
self.cursor.execute("""
@@ -137,8 +140,11 @@ def update_processed_file(self, s3_path, etag):

def clear_index(self):
self.cursor.execute("""
-DELETE FROM processed_files; DELETE FROM page_results;
-""", (s3_path, etag))
+DELETE FROM processed_files;
+""")
+self.cursor.execute("""
+DELETE FROM page_results;
+""")
self.conn.commit()

def add_index_entries(self, index_entries: List[BatchInferenceRecord]):
@@ -274,7 +280,7 @@ def write_line(self, obj: Optional[Any]):
if obj is None:
return

-line_bytes = json.dumps(obj, ensure_ascii=False).encode("utf-8")
+line_bytes = orjson.dumps(obj)
line_size = len(line_bytes) + 1 # +1 for newline

if self.batch_size + line_size > self.max_size:
@@ -393,9 +399,8 @@ def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchI
line_length = len(line) # Length in bytes

try:
-# Decode the line for JSON processing
-line_str = line.decode('utf-8')
-data = json.loads(line_str)
+# Parse the line directly as JSON
+data = orjson.loads(line)
pdf_s3_path, page_num = parse_custom_id(data["custom_id"])

if data.get("completion_error", None) is not None:
@@ -414,11 +419,12 @@ def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchI
assert "outputs" in data and len(data["outputs"]) > 0, "No outputs from model detected"

try:
-model_response_json = json.loads(data["outputs"][0]["text"])
+model_response_json = orjson.loads(data["outputs"][0]["text"])
+page_response = PageResponse(**model_response_json)

last_error = data.get("completion_error", None)

-if not model_response_json["is_rotation_valid"]:
+if not page_response.is_rotation_valid:
last_error = "rotation_invalid"

index_entries.append(DatabaseManager.BatchInferenceRecord(
@@ -431,7 +437,8 @@ def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchI
finish_reason=data["outputs"][0]["finish_reason"],
error=last_error,
))
-except json.JSONDecodeError:
+except Exception as e:
+error_type = type(e).__name__
index_entries.append(DatabaseManager.BatchInferenceRecord(
inference_s3_path=inference_s3_path,
pdf_s3_path=pdf_s3_path,
@@ -440,14 +447,12 @@ def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchI
start_index=start_index, # Byte offset in the original file
length=line_length, # Length in bytes
finish_reason=data["outputs"][0]["finish_reason"],
error="Could not parse model JSON output",
error=error_type,
))

-except json.JSONDecodeError:
-print(f"Error with JSON Decoding of inference in {inference_s3_path}")
-# TODO Maybe this needs to add an index error that this json is bad
except Exception as e:
-print(f"Error processing line: {e}")
+print(f"Error processing line in {inference_s3_path}: {e}")
# Optionally, you might want to add an index entry indicating an error here

start_index += line_length # Increment by the number of bytes

@@ -468,7 +473,7 @@ def get_pdf_num_pages(s3_path: str) -> Optional[int]:
return None

def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord, cur_round: int, target_longest_image_dim: int, target_anchor_text_len: int) -> list[dict]:
-db = DatabaseManager(s3_workspace)
+db = DatabaseManager(s3_workspace, skip_init=True)

existing_pages = db.get_index_entries(pdf.s3_path)
new_queries = []
@@ -504,7 +509,7 @@ def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord, cur_rou
return new_queries

def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Optional[dict]:
-db = DatabaseManager(s3_workspace)
+db = DatabaseManager(s3_workspace, skip_init=True)
existing_pages = db.get_index_entries(pdf.s3_path)
document_text = ""
last_page_start_index = 0
@@ -520,19 +525,24 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Option
start_index=page.start_index,
end_index=page.start_index + page.length - 1) for page in usable_pages]

-usable_page_final_results = [json.loads(json.loads(page_data.decode("utf-8"))["outputs"][0]["text"]) for page_data in usable_page_data]
+usable_page_final_results = []
+for page_data in usable_page_data:
+data = orjson.loads(page_data)
+model_response_json = orjson.loads(data["outputs"][0]["text"])
+page_response = PageResponse(**model_response_json)
+usable_page_final_results.append(page_response)

# Sort the pages:
# 1. Prefer pages with `is_rotation_valid` set to True.
# 2. Within those, sort by the length of the `natural_text` in descending order.
usable_page_final_results.sort(
key=lambda page: (not page["is_rotation_valid"], -len(page["natural_text"] if page["natural_text"] else ""))
key=lambda page: (not page.is_rotation_valid, -len(page.natural_text or ""))
)

target_page_final_result = usable_page_final_results[0]

-if target_page_final_result["natural_text"] is not None:
-document_text += target_page_final_result["natural_text"] + "\n"
+if target_page_final_result.natural_text is not None:
+document_text += target_page_final_result.natural_text + "\n"

pdf_page_spans.append([last_page_start_index, len(document_text), target_page_num])
last_page_start_index = len(document_text)
@@ -558,7 +568,7 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Option
return dolma_doc

def mark_pdfs_done(s3_workspace: str, dolma_docs: list[dict]):
-db = DatabaseManager(s3_workspace)
+db = DatabaseManager(s3_workspace, skip_init=True)

for doc in dolma_docs:
db.update_pdf_status(doc["metadata"]["Source-File"], "completed")
@@ -599,7 +609,7 @@ def get_current_round(s3_workspace: str) -> int:
parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
parser.add_argument('--max_size_mb', type=int, default=250, help='Max file size in MB')
-parser.add_argument('--reindex', action='set_true', default=False, help='Reindex all of the page_results')
+parser.add_argument('--reindex', action='store_true', default=False, help='Reindex all of the page_results')
args = parser.parse_args()

if args.workspace_profile:
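
For context on the error-handling change in process_jsonl_content: parsing the model output into a PageResponse and widening the except clause means both malformed JSON and JSON that does not match the expected fields land in the same path, with the exception class name recorded as the error. A rough sketch under assumptions, using a hypothetical stand-in PageResponse with only the two fields this file touches (the real class lives in pdelfin.prompts):

from dataclasses import dataclass
from typing import Optional

import orjson

@dataclass(frozen=True)
class PageResponse:  # hypothetical stand-in, not the real pdelfin.prompts class
    is_rotation_valid: bool
    natural_text: Optional[str]

for raw in (b'{"is_rotation_valid": true, "natural_text": "ok"}',
            b'not json at all',
            b'{"unexpected_field": 1}'):
    try:
        page = PageResponse(**orjson.loads(raw))
        error = None if page.is_rotation_valid else "rotation_invalid"
    except Exception as e:        # mirrors the widened handler in the commit
        error = type(e).__name__  # e.g. "JSONDecodeError" or "TypeError"
    print(error)                  # None, then JSONDecodeError, then TypeError
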