Skip to content

Commit

Permalink
Switching to logging vs prints
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Oct 28, 2024
1 parent a3e7654 commit 45269fa
Showing 1 changed file with 43 additions and 36 deletions.
79 changes: 43 additions & 36 deletions pdelfin/birrpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
from pdelfin.s3_utils import parse_custom_id, expand_s3_glob, get_s3_bytes, parse_s3_path


# Initialize logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Global s3 client for the whole script, feel free to adjust params if you need it
workspace_s3 = boto3.client('s3')
pdf_s3 = boto3.client('s3')
Expand Down Expand Up @@ -65,7 +69,9 @@ def __init__(self, s3_workspace: str, skip_init: bool=False):
self.db_path = os.path.join(home_cache_dir, 'index.db')

self.conn = sqlite3.connect(self.db_path)
# Enable WAL mode so you can read and write concurrently
self.cursor = self.conn.cursor()
self.cursor.execute("PRAGMA journal_mode=WAL;")

if not skip_init:
self._initialize_tables()
Expand Down Expand Up @@ -147,15 +153,15 @@ def clear_index(self):
""")
self.conn.commit()

def add_index_entries(self, index_entries: List[BatchInferenceRecord]):
def add_index_entries(self, index_entries: List['BatchInferenceRecord']):
if index_entries:
self.cursor.executemany("""
INSERT INTO page_results (inference_s3_path, pdf_s3_path, page_num, round, start_index, length, finish_reason, error)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""", [(entry.inference_s3_path, entry.pdf_s3_path, entry.page_num, entry.round, entry.start_index, entry.length, entry.finish_reason, entry.error) for entry in index_entries])
self.conn.commit()

def get_index_entries(self, pdf_s3_path: str) -> List[BatchInferenceRecord]:
def get_index_entries(self, pdf_s3_path: str) -> List['BatchInferenceRecord']:
self.cursor.execute("""
SELECT inference_s3_path, pdf_s3_path, page_num, round, start_index, length, finish_reason, error
FROM page_results
Expand Down Expand Up @@ -204,9 +210,9 @@ def add_pdf(self, s3_path: str, num_pages: int, status: str = 'pending') -> None
""", (s3_path, num_pages, status))
self.conn.commit()
except sqlite3.IntegrityError:
print(f"PDF with s3_path '{s3_path}' already exists.")
logger.warning(f"PDF with s3_path '{s3_path}' already exists.")

def update_pdf_statuses(self, status_updates: dict[str, str]) -> None:
def update_pdf_statuses(self, status_updates: Dict[str, str]) -> None:
"""
Update the status of multiple PDFs in the database.
Expand All @@ -220,7 +226,7 @@ def update_pdf_statuses(self, status_updates: dict[str, str]) -> None:
""", [(new_status, s3_path) for s3_path, new_status in status_updates.items()])
self.conn.commit()

def get_pdf(self, s3_path: str) -> Optional[PDFRecord]:
def get_pdf(self, s3_path: str) -> Optional['PDFRecord']:
self.cursor.execute("""
SELECT s3_path, num_pages, status
FROM pdfs
Expand All @@ -237,7 +243,7 @@ def get_pdf(self, s3_path: str) -> Optional[PDFRecord]:
)
return None

def get_pdfs_by_status(self, status: str) -> List[PDFRecord]:
def get_pdfs_by_status(self, status: str) -> List['PDFRecord']:
self.cursor.execute("""
SELECT s3_path, num_pages, status
FROM pdfs
Expand Down Expand Up @@ -334,7 +340,7 @@ def _write_batch_to_file(self, temp_file_path: str, batch_objects: List[Any]):
try:
workspace_s3.upload_file(temp_file_path, bucket, key)
except Exception as e:
print(f"Failed to upload {temp_file_path} to {output_path}: {e}")
logger.error(f"Failed to upload {temp_file_path} to {output_path}: {e}", exc_info=True)
else:
# Move the temp file to the output path
os.rename(temp_file_path, output_path)
Expand Down Expand Up @@ -457,7 +463,7 @@ def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchI
))

except Exception as e:
print(f"Error processing line in {inference_s3_path}: {e}")
logger.exception(f"Error processing line in {inference_s3_path}: {e}")
# Optionally, you might want to add an index entry indicating an error here

start_index += line_length # Increment by the number of bytes
Expand All @@ -474,7 +480,7 @@ def get_pdf_num_pages(s3_path: str) -> Optional[int]:
reader = PdfReader(tf.name)
return reader.get_num_pages()
except Exception as ex:
print(f"Warning, could not add {s3_path} due to {ex}")
logger.warning(f"Warning, could not add {s3_path} due to {ex}")

return None

Expand Down Expand Up @@ -510,7 +516,7 @@ def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord, cur_rou
else:
new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num, target_longest_image_dim, target_anchor_text_len), "round": cur_round})
except Exception as ex:
print(f"Warning, could not get batch inferences lines for {pdf.s3_path} due to {ex}")
logger.warning(f"Warning, could not get batch inferences lines for {pdf.s3_path} due to {ex}")

return new_queries

Expand Down Expand Up @@ -626,21 +632,22 @@ def get_current_round(s3_workspace: str) -> int:
pdf_s3 = pdf_session.client("s3")

db = DatabaseManager(args.workspace)
print(f"Loaded db at {db.db_path}")
logger.info(f"Loaded db at {db.db_path}")

if args.reindex:
db.clear_index()
logger.info("Cleared existing index.")

current_round = get_current_round(args.workspace)
print(f"Current round is {current_round}\n")
logger.info(f"Current round is {current_round}")

# One shared executor to rule them all
executor = ProcessPoolExecutor()

# If you have new PDFs, step one is to add them to the list
if args.add_pdfs:
if args.add_pdfs.startswith("s3://"):
print(f"Querying all PDFs at {args.add_pdfs}")
logger.info(f"Querying all PDFs at {args.add_pdfs}")

all_pdfs = expand_s3_glob(pdf_s3, args.add_pdfs)
print(f"Found {len(all_pdfs):,} total pdf paths")
Expand All @@ -651,52 +658,52 @@ def get_current_round(s3_workspace: str) -> int:
raise ValueError("add_pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)")

all_pdfs = [pdf for pdf in all_pdfs if not db.pdf_exists(pdf)]
print(f"Need to import {len(all_pdfs):,} total new pdf paths")
logger.info(f"Need to import {len(all_pdfs):,} total new pdf paths")

future_to_path = {executor.submit(get_pdf_num_pages, s3_path): s3_path for s3_path in all_pdfs}
for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
for future in tqdm(as_completed(future_to_path), total=len(future_to_path), desc="Adding PDFs"):
s3_path = future_to_path[future]
num_pages = future.result()
if num_pages and not db.pdf_exists(s3_path):
db.add_pdf(s3_path, num_pages, "pending")

print("\n")
logger.info("Completed adding new PDFs.")

# Now build an index of all the pages that were processed within the workspace so far
print("Indexing all batch inference sent to this workspace")
logger.info("Indexing all batch inference sent to this workspace")
inference_output_paths = expand_s3_glob(workspace_s3, f"{args.workspace}/inference_outputs/*.jsonl")

inference_output_paths = {
s3_path: etag for s3_path, etag in inference_output_paths.items()
if not db.is_file_processed(s3_path, etag)
}

print(f"Found {len(inference_output_paths):,} new batch inference results to index")
logger.info(f"Found {len(inference_output_paths):,} new batch inference results to index")
future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths.items()}

for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
for future in tqdm(as_completed(future_to_path), total=len(future_to_path), desc="Indexing Inference Results"):
s3_path, etag = future_to_path[future]

try:
inference_records = future.result()

db.delete_index_entries_by_inference_s3_path(s3_path)
db.add_index_entries(inference_records)
db.update_processed_file(s3_path, etag=etag)
except urllib3.exceptions.SSLError:
print(f"Cannot load inference file {s3_path} due to SSL error, will retry another time")

logger.warning(f"Cannot load inference file {s3_path} due to SSL error, will retry another time")
except Exception as e:
logger.exception(f"Failed to index inference file {s3_path}: {e}")

# Now query each pdf, if you have all of the pages needed (all pages present, error is null and finish_reason is stop), then you assemble it into a dolma document and output it
# If you don't have every page, or if you have pages with errors, then you output a new batch of inference items to use
if db.get_last_indexed_round() < current_round - 1:
print(f"WARNING: No new batch inference results found, you need to run batch inference on {args.workspace}/inference_inputs/round_{current_round - 1}")
logger.warning(f"WARNING: No new batch inference results found, you need to run batch inference on {args.workspace}/inference_inputs/round_{current_round - 1}")
potentially_done_pdfs = db.get_pdfs_by_status("pending")
elif args.skip_build_queries:
print(f"Skipping generating new batch inference files")
logger.info(f"Skipping generating new batch inference files")
potentially_done_pdfs = db.get_pdfs_by_status("pending")
else:
print(f"\nCreating batch inference files for new PDFs")
logger.info("Creating batch inference files for new PDFs")
pdf_list = list(db.get_pdfs_by_status("pending"))
pdf_iter = iter(pdf_list)
pending_futures = {}
Expand All @@ -706,7 +713,7 @@ def get_current_round(s3_workspace: str) -> int:
total_pdfs = len(pdf_list)
max_pending = 300

with tqdm(total=total_pdfs) as pbar:
with tqdm(total=total_pdfs, desc="Building Batch Queries") as pbar:
# Submit initial batch of futures
for _ in range(min(max_pending, total_pdfs)):
pdf = next(pdf_iter)
Expand Down Expand Up @@ -750,14 +757,14 @@ def get_current_round(s3_workspace: str) -> int:
new_inference_writer.close()

if lines_written > 0:
print(f"Added {lines_written:,} new batch inference requests")
logger.info(f"Added {lines_written:,} new batch inference requests")

# Now, finally, assemble any potentially done docs into dolma documents
print(f"\nAssembling potentially finished PDFs into Dolma documents at {args.workspace}/output")
logger.info(f"Assembling potentially finished PDFs into Dolma documents at {args.workspace}/output")
future_to_path = {executor.submit(build_dolma_doc, args.workspace, pdf): pdf for pdf in potentially_done_pdfs}
new_output_writer = BatchWriter(f"{args.workspace}/output", args.max_size_mb, after_flush=partial(mark_pdfs_done, args.workspace))

for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
for future in tqdm(as_completed(future_to_path), total=len(future_to_path), desc="Assembling Dolma Docs"):
pdf = future_to_path[future]
dolma_doc = future.result()

Expand All @@ -766,14 +773,14 @@ def get_current_round(s3_workspace: str) -> int:

new_output_writer.close()

print("\nFinal statistics:")
logger.info("Final statistics:")

# Output the number of PDFs in each status "pending" and "completed"
pending_pdfs = db.get_pdfs_by_status("pending")
completed_pdfs = db.get_pdfs_by_status("completed")

print(f"Pending PDFs: {len(pending_pdfs):,} ({sum(doc.num_pages for doc in pending_pdfs):,} pages)")
print(f"Completed PDFs: {len(completed_pdfs):,} ({sum(doc.num_pages for doc in completed_pdfs):,} pages)")
logger.info(f"Pending PDFs: {len(pending_pdfs):,} ({sum(doc.num_pages for doc in pending_pdfs):,} pages)")
logger.info(f"Completed PDFs: {len(completed_pdfs):,} ({sum(doc.num_pages for doc in completed_pdfs):,} pages)")

# For each round, outputs a report of how many pages were processed, how many had errors, and a breakdown by (error, finish_reason)
total_rounds = db.get_last_indexed_round() + 1
Expand All @@ -788,12 +795,12 @@ def get_current_round(s3_workspace: str) -> int:
results = db.cursor.fetchall()

total_pages = sum(count for count, _, _ in results)
print(f"\nInference Round {round_num} - {total_pages:,} pages processed:")
logger.info(f"Inference Round {round_num} - {total_pages:,} pages processed:")

for count, error, finish_reason in results:
error_str = error if error is not None else "None"
print(f" (error: {error_str}, finish_reason: {finish_reason}) -> {count:,} pages")
logger.info(f" (error: {error_str}, finish_reason: {finish_reason}) -> {count:,} pages")

print("\nWork finished, waiting for all workers to finish cleaning up")
logger.info("Work finished, waiting for all workers to finish cleaning up")
executor.shutdown(wait=True)
db.close()

0 comments on commit 45269fa

Please sign in to comment.