Commit: Small fixes

jakep-allenai committed Jan 10, 2025
1 parent 2190f61 commit 0d1fc08

Showing 4 changed files with 106 additions and 27 deletions.
4 changes: 2 additions & 2 deletions pdelfin/beakerpipeline.py
@@ -718,7 +718,7 @@ def submit_beaker_job(args):

def print_stats(args):
LONG_CONTEXT_THRESHOLD = 32768

# Get total work items and completed items
index_file_s3_path = os.path.join(args.workspace, "work_index_list.csv.zstd")
output_glob = os.path.join(args.workspace, "results", "*.jsonl")
@@ -858,7 +858,7 @@ async def main():
# Beaker/job running stuff
parser.add_argument('--beaker', action='store_true', help='Submit this job to beaker instead of running locally')
parser.add_argument('--beaker_workspace', help='Beaker workspace to submit to', default='ai2/pdelfin')
-    parser.add_argument('--beaker_cluster', help='Beaker clusters you want to run on', default=["ai2/jupiter-cirrascale-2", "ai2/pluto-cirrascale", "ai2/neptune-cirrascale", "ai2/saturn-cirrascale", "ai2/augusta-google-1"])
+    parser.add_argument('--beaker_cluster', help='Beaker clusters you want to run on', default=["ai2/jupiter-cirrascale-2", "ai2/ceres-cirrascale", "ai2/neptune-cirrascale", "ai2/saturn-cirrascale", "ai2/augusta-google-1"])
parser.add_argument('--beaker_gpus', type=int, default=1, help="Number of gpu replicas to run")
parser.add_argument('--beaker_priority', type=str, default="normal", help="Beaker priority level for the job")
args = parser.parse_args()
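
As an aside on the cluster-list change above: argparse hands back a list default unchanged when the flag is omitted, but a value passed on the command line arrives as a single string. A minimal standalone sketch (illustrative only, not from this repo):

import argparse

parser = argparse.ArgumentParser()
# Omitting --beaker_cluster yields the list below;
# passing it on the command line yields a plain string instead.
parser.add_argument('--beaker_cluster',
                    default=["ai2/jupiter-cirrascale-2", "ai2/ceres-cirrascale"])
args = parser.parse_args([])
print(args.beaker_cluster)  # ['ai2/jupiter-cirrascale-2', 'ai2/ceres-cirrascale']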
116 changes: 93 additions & 23 deletions pdelfin/data/convertsilver_birr.py
@@ -5,11 +5,14 @@
from concurrent.futures import ProcessPoolExecutor, as_completed
import sys
import logging
+import tempfile
+import os

import smart_open
-from cached_path import cached_path
+import boto3
from pdelfin.prompts import build_finetuning_prompt
from pdelfin.prompts.anchor import get_anchor_text
+from pdelfin.data.renderpdf import render_pdf_to_base64png

# Import Plotly for plotting
import plotly.express as px
@@ -31,6 +34,38 @@ def is_s3_path(path):
return str(path).startswith('s3://')


+def download_pdf_from_s3(s3_path: str, pdf_profile: str) -> str:
+    """
+    Downloads a PDF file from S3 to a temporary local file and returns the local file path.
+
+    Args:
+        s3_path (str): S3 path in the format s3://bucket/key
+        pdf_profile (str): The name of the boto3 profile to use.
+
+    Returns:
+        str: Path to the downloaded PDF file in the local filesystem.
+    """
+    # Parse the bucket and key from the s3_path
+    # s3_path format: s3://bucket_name/some/folder/file.pdf
+    path_without_scheme = s3_path.split('s3://', 1)[1]
+    bucket_name, key = path_without_scheme.split('/', 1)
+
+    # Create a session with the specified profile or default
+    session = boto3.Session(profile_name=pdf_profile) if pdf_profile else boto3.Session()
+    s3_client = session.client('s3')
+
+    # Create a temporary local file; we only want its path, not an open handle
+    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.pdf')
+    tmp_file.close()
+
+    local_path = tmp_file.name
+
+    logging.info(f"Downloading PDF from {s3_path} to {local_path} using profile {pdf_profile}")
+    s3_client.download_file(bucket_name, key, local_path)
+
+    return local_path
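
For reference, a small illustration of the bucket/key split this helper performs (the s3_path below is hypothetical):

# How download_pdf_from_s3 splits an s3:// URI into bucket and key.
s3_path = "s3://my-bucket/some/folder/file.pdf"  # made-up example path
path_without_scheme = s3_path.split("s3://", 1)[1]
bucket_name, key = path_without_scheme.split("/", 1)
assert bucket_name == "my-bucket"
assert key == "some/folder/file.pdf"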


def transform_json_object(obj):
"""
Transform a single JSON object by extracting and renaming specific fields.
@@ -39,7 +74,7 @@ def transform_json_object(obj):
obj (dict): Original JSON object.
Returns:
-        dict: Transformed JSON object.
+        dict or None: Transformed JSON object, or None if there's an error.
"""
try:
transformed = {
@@ -54,14 +89,15 @@ def transform_json_object(obj):
return None


-def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool):
+def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool, pdf_profile: str):
"""
Process a single JSONL file: read, transform, and write to output.
Args:
input_file (str): Path or URL to the input JSONL file.
output_file (str): Path or URL to the output JSONL file.
rewrite_prompt_str (bool): Flag to rewrite the prompt string.
+        pdf_profile (str): Boto3 profile to use when fetching PDFs from S3.
"""
processed_count = 0
error_count = 0
@@ -85,25 +121,49 @@ def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool):
transformed = transform_json_object(obj)

if transformed is not None and rewrite_prompt_str:
# We look for RAW_TEXT_START ... RAW_TEXT_END in the existing content
pattern = r"RAW_TEXT_START\s*\n(.*?)\nRAW_TEXT_END"

# Use re.DOTALL to ensure that the dot matches newline characters
match = re.search(pattern, transformed["chat_messages"][0]["content"][0]["text"], re.DOTALL)

if match:
raw_page_text = match.group(1).strip()

-                    # Ok, now we want to try to see if it's better if we recalculate the anchor text
+                    # We found raw page text, but we'll attempt to regenerate it
goldkey = obj["custom_id"]
-                    s3_path = goldkey[:goldkey.rindex("-")]
-                    page = int(goldkey[goldkey.rindex("-") + 1:])
+                    # goldkey might look like: "s3://bucket/path/to/file.pdf-23"
+                    # s3_path = everything up to the last dash
+                    # page = everything after the dash
+                    try:
+                        s3_path = goldkey[:goldkey.rindex("-")]
+                        page = int(goldkey[goldkey.rindex("-") + 1:])
+                    except (ValueError, IndexError) as e:
+                        logging.error(f"Could not parse the page number from custom_id {goldkey}: {e}")
+                        error_count += 1
+                        continue

+                    # If the path is an S3 path, download to a local temp file; else assume local
+                    if is_s3_path(s3_path):
+                        local_pdf_path = download_pdf_from_s3(s3_path, pdf_profile)
+                    else:
+                        local_pdf_path = s3_path
+
+                    # Recalculate the anchor text
+                    raw_page_text = get_anchor_text(
+                        local_pdf_path,
+                        page,
+                        pdf_engine="pdfreport",
+                        target_length=6000
+                    )
+
+                    image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)

-                    # Save the pdf to a temporary cache folder
-                    local_pdf_path = cached_path(s3_path, quiet=True)
-
-                    raw_page_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport", target_length=6000)
transformed["chat_messages"][0]["content"][0]["text"] = build_finetuning_prompt(raw_page_text)
+                    transformed["chat_messages"][0]["content"][1]["image_url"]["url"] = f"data:image/png;base64,{image_base64}"

+                    # Clean up the temp PDF file if it was downloaded
+                    if is_s3_path(s3_path):
+                        try:
+                            os.remove(local_pdf_path)
+                        except OSError as remove_err:
+                            logging.error(f"Failed to remove temporary PDF file {local_pdf_path}: {remove_err}")

if transformed is not None:
prompt_text = transformed["chat_messages"][0]["content"][0]["text"]
@@ -175,8 +235,6 @@ def list_input_files(input_dir):
list: List of input file paths.
"""
if is_s3_path(input_dir):
-        # Use smart_open's s3 functionality to list files
-        import boto3
import fnmatch

# Parse bucket and prefix
@@ -191,19 +249,18 @@ def list_input_files(input_dir):
prefix = ''
pattern = path_and_pattern

-        # Set up S3 resource and bucket
-        s3 = boto3.resource('s3')
+        # Use a Boto3 session (no specific PDF profile needed here if only listing)
+        session = boto3.Session()
+        s3 = session.resource('s3')
bucket = s3.Bucket(bucket_name)

# Get all objects and filter them manually based on the pattern
files = []
for obj in bucket.objects.filter(Prefix=prefix):
if fnmatch.fnmatch(obj.key, f'{prefix}{pattern}'):
files.append(f's3://{bucket_name}/{obj.key}')

return files
else:
# Local path handling (with glob pattern)
input_dir_path = Path(input_dir)
return [str(p) for p in input_dir_path.glob('*.jsonl')]
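
The S3 branch above filters keys client-side with fnmatch rather than asking S3 to glob; a small standalone sketch of that filtering (the keys are made up):

import fnmatch

keys = ["data/train/a.jsonl", "data/train/b.txt", "data/eval/c.jsonl"]
prefix, pattern = "data/train/", "*.jsonl"
matches = [k for k in keys if fnmatch.fnmatch(k, f"{prefix}{pattern}")]
print(matches)  # ['data/train/a.jsonl']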

@@ -216,8 +273,8 @@ def main():
parser.add_argument(
'--rewrite_finetuning_prompt',
action='store_true',
-        default=False,
-        help="Rewrites the input prompt from standard OPENAI instruction format into our finetuned format"
+        default=True,
+        help="Rewrite the input prompt from a standard OPENAI instruction format into a finetuned format."
)
parser.add_argument(
'input_dir',
@@ -235,6 +292,13 @@ def main():
default=20,
help='Number of parallel jobs to run (default: 20).'
)
+    parser.add_argument(
+        '--pdf_profile',
+        type=str,
+        default=None,
+        help='Boto3 profile to use for downloading PDFs from S3. Defaults to the default session.'
+    )
+
args = parser.parse_args()

input_dir = args.input_dir.rstrip('/')
@@ -260,7 +324,13 @@ def main():
all_prompt_lengths = []
with ProcessPoolExecutor(max_workers=max_jobs) as executor:
future_to_file = {
-            executor.submit(process_file, input_file, output_file, args.rewrite_finetuning_prompt): input_file
+            executor.submit(
+                process_file,
+                input_file,
+                output_file,
+                args.rewrite_finetuning_prompt,
+                args.pdf_profile
+            ): input_file
for input_file, output_file in tasks
}
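
The future_to_file mapping built above follows the standard as_completed pattern; a minimal self-contained sketch (work is a hypothetical stand-in for process_file):

from concurrent.futures import ProcessPoolExecutor, as_completed

def work(path):
    # Hypothetical stand-in for process_file.
    return path, len(path)

if __name__ == "__main__":
    tasks = ["a.jsonl", "b.jsonl"]
    with ProcessPoolExecutor(max_workers=2) as executor:
        future_to_file = {executor.submit(work, p): p for p in tasks}
        for future in as_completed(future_to_file):
            input_file = future_to_file[future]
            print(input_file, future.result())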

2 changes: 1 addition & 1 deletion pdelfin/data/runopenaibatch.py
@@ -50,7 +50,7 @@ def download_batch_result(batch_id, output_folder):
return batch_id, False

if batch_data.output_file_id is None:
-        print(f"WARNING: {batch_id} is completed, but not output file was generated")
+        print(f"WARNING: {batch_id} is completed, but no output file was generated")
return batch_id, False

print(f"Downloading batch data for {batch_id}")
11 changes: 10 additions & 1 deletion tests/test_anchor.py
@@ -196,4 +196,13 @@ def testFastMediaBoxMatchesPyPdf(self):
pypdfpage = reader.pages[page_num - 1]

self.assertAlmostEqual(w1, pypdfpage.mediabox.width, places=3)
-        self.assertAlmostEqual(h1, pypdfpage.mediabox.height, places=3)
\ No newline at end of file
+        self.assertAlmostEqual(h1, pypdfpage.mediabox.height, places=3)

+class TestOutputSamplePage(unittest.TestCase):
+    def testTobaccoPaper(self):
+        local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "tobacco_missed_tokens_pg1.pdf")
+        anchor_text = get_anchor_text(local_pdf_path, 1, 'pdfreport', target_length=6000)
+
+        print("")
+        print(anchor_text)
+        print("")
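
To run just the new test above in isolation, something like this should work (assuming the repo's tests/ layout is importable):

import unittest

# Load and run only TestOutputSamplePage.testTobaccoPaper.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "tests.test_anchor.TestOutputSamplePage.testTobaccoPaper"
)
unittest.TextTestRunner(verbosity=2).run(suite)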
