Skip to content

Commit

Permalink
Work queue coallescing
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Nov 7, 2024
1 parent 57186c7 commit b15bff6
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 14 deletions.
55 changes: 41 additions & 14 deletions pdelfin/beakerpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def compute_workgroup_sha1(work_group: list[str]) -> str:
sha1.update(pdf.encode('utf-8'))
return sha1.hexdigest()


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
Expand Down Expand Up @@ -76,12 +77,18 @@ def compute_workgroup_sha1(work_group: list[str]) -> str:

all_pdfs = set(all_pdfs)
logger.info(f"Found {len(all_pdfs):,} total pdf paths")

existing_lines = download_zstd_csv(workspace_s3, index_file_s3_path)

# Parse existing work items into groups
existing_groups = [line.strip().split(",") for line in existing_lines if line.strip()]
existing_pdf_set = set(pdf for group in existing_groups for pdf in group)
existing_groups = {}
for line in existing_lines:
if line.strip():
parts = line.strip().split(",")
group_hash = parts[0]
group_pdfs = parts[1:]
existing_groups[group_hash] = group_pdfs
existing_pdf_set = set(pdf for group_pdfs in existing_groups.values() for pdf in group_pdfs)

logger.info(f"Loaded {len(existing_pdf_set):,} existing pdf paths from the workspace")

Expand All @@ -96,18 +103,22 @@ def compute_workgroup_sha1(work_group: list[str]) -> str:
for pdf in sorted(new_pdfs): # Sort for consistency
current_group.append(pdf)
if len(current_group) == args.group_size:
new_groups.append(current_group)
group_hash = compute_workgroup_sha1(current_group)
new_groups.append((group_hash, current_group))
current_group = []
if current_group:
new_groups.append(current_group)
group_hash = compute_workgroup_sha1(current_group)
new_groups.append((group_hash, current_group))

logger.info(f"Created {len(new_groups):,} new work groups")

# Combine existing groups with new groups
combined_groups = existing_groups + new_groups
combined_groups = existing_groups.copy()
for group_hash, group_pdfs in new_groups:
combined_groups[group_hash] = group_pdfs

# Prepare lines to write back
combined_lines = [",".join(group) for group in combined_groups]
combined_lines = [",".join([group_hash] + group_pdfs) for group_hash, group_pdfs in combined_groups.items()]

# Upload the combined work items back to S3
if new_groups:
Expand All @@ -119,9 +130,9 @@ def compute_workgroup_sha1(work_group: list[str]) -> str:
# If there is a beaker flag, then your job is to trigger this script with N replicas on beaker
# If not, then your job is to do the actual work

# Donwload the model from the best place available
# Download the model from the best place available
model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
#download_directory(args.model, model_cache_dir)
download_directory(args.model, model_cache_dir)

# Start up the sglang server
sglang_process = subprocess.Popen([
Expand All @@ -131,7 +142,6 @@ def compute_workgroup_sha1(work_group: list[str]) -> str:
"--context-length", str(args.model_max_context),
])


# Register atexit function and signal handlers to guarantee process termination
def terminate_processes():
print("Terminating child processes...")
Expand All @@ -153,11 +163,28 @@ def signal_handler(sig, frame):
signal.signal(signal.SIGTERM, signal_handler)

# Read in the work queue from s3
work_queue = download_zstd_csv(workspace_s3, index_file_s3_path)
work_queue = {compute_workgroup_sha1(pdfs): pdfs for pdfs in work_queue}
work_queue_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
work_queue = {}
for line in work_queue_lines:
if line.strip():
parts = line.strip().split(",")
group_hash = parts[0]
group_pdfs = parts[1:]
work_queue[group_hash] = group_pdfs

# Read in the done items from the s3 workspace
done_work_items = expand_s3_glob(workspace_s3, f"{args.workspace}/dolma_documents/*.jsonl")
done_work_items = expand_s3_glob(workspace_s3, f"{args.workspace}/dolma_documents/output_*.jsonl")
done_work_hashes = set()
for item in done_work_items:
filename = os.path.basename(item)
if filename.startswith('output_') and filename.endswith('.jsonl'):
group_hash = filename[len('output_'):-len('.jsonl')]
done_work_hashes.add(group_hash)

remaining_work_hashes = set(work_queue.keys()) - done_work_hashes
remaining_work_queue = {hash: work_queue[hash] for hash in remaining_work_hashes}

logger.info(f"Remaining work items: {len(remaining_work_queue)}")

# TODO
# Spawn up to N workers to do:
Expand All @@ -178,4 +205,4 @@ def signal_handler(sig, frame):
logger.error(f"Sglang server exited with code {sglang_process.returncode} exiting.")
except KeyboardInterrupt:
logger.info("Got keyboard interrupt, exiting everything")
sys.exit(1)
sys.exit(1)
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ dependencies = [
"markdown2",
"filelock",
"orjson",
"requests",
"zstandard",
]
license = {file = "LICENSE"}

Expand Down

0 comments on commit b15bff6

Please sign in to comment.