Quicker results by limited workers via semaphore while still utilizing gpu
jakep-allenai committed Nov 12, 2024
1 parent 6154095 commit 4f2f4fd
Showing 1 changed file with 51 additions and 9 deletions.
60 changes: 51 additions & 9 deletions pdelfin/beakerpipeline.py
@@ -14,6 +14,7 @@
import aiohttp
import datetime
import tempfile
import re

from tqdm import tqdm
from io import BytesIO
@@ -73,7 +74,7 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_

    # Allow the page rendering to process in the background while we get the anchor text (which blocks the main thread)
    image_base64 = asyncio.to_thread(render_pdf_to_base64png, local_pdf_path, page, target_longest_image_dim=target_longest_image_dim)

    # GET ANCHOR TEXT IS NOT THREAD SAFE!! Ahhhh..... don't try to do it
    # and it's also CPU bound, so it needs to run in a process pool
    loop = asyncio.get_running_loop()
@@ -287,7 +288,7 @@ async def process_pdf(args, pdf_s3_path: str):
        logger.exception(f"Could not load page for {pdf_s3_path}, aborting document")
        return None


    # Build the document text and page spans
    document_text = ""
    pdf_page_spans = []
@@ -332,11 +333,14 @@ async def process_pdf(args, pdf_s3_path: str):
    return dolma_doc


async def worker(args, queue):
async def worker(args, queue, semaphore):
    while True:
        [work_hash, pdfs] = await queue.get()

        try:
            # Wait until allowed to proceed
            await semaphore.acquire()

            dolma_docs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs])
            dolma_docs = [doc for doc in dolma_docs if doc is not None]

@@ -372,7 +376,7 @@ async def worker(args, queue):
            queue.task_done()


async def sglang_server_task(args):
async def sglang_server_task(args, semaphore):
    model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
    # TODO cache locally
    #download_directory(args.model, model_cache_dir)
@@ -390,20 +394,53 @@ async def sglang_server_task(args):

    proc = await asyncio.create_subprocess_exec(
        "python3",

        "-m", "sglang.launch_server",
        "--model-path", model_cache_dir,
        "--chat-template", args.model_chat_template,
        "--context-length", str(args.model_max_context),
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )

    # Make really sure we kill this subprocess on exit
    # Make sure we kill this subprocess on exit
    def _kill_proc():
        proc.terminate()

    atexit.register(_kill_proc)

    last_queue_req = None  # To track transitions
    async def process_line(line):
        # Parse the line and update semaphore if necessary
        match = re.search(r'#running-req: (\d+), #queue-req: (\d+)', line)
        if match:
            logger.info(line)
            running_req = int(match.group(1))
            queue_req = int(match.group(2))

            nonlocal last_queue_req
            if last_queue_req is not None and last_queue_req != 0 and queue_req == 0:
                # Release the semaphore when queue_req transitions from non-zero to zero
                if semaphore.locked():
                    semaphore.release()
                    logger.info("Semaphore released, allowing a worker to proceed.")

            last_queue_req = queue_req

    async def read_stream(stream):
        while True:
            line = await stream.readline()
            if not line:
                break
            line = line.decode('utf-8').rstrip()
            await process_line(line)

    # Start tasks to read stdout and stderr
    stdout_task = asyncio.create_task(read_stream(proc.stdout))
    stderr_task = asyncio.create_task(read_stream(proc.stderr))

    await proc.wait()
    await stdout_task
    await stderr_task


async def sglang_server_ready():
@@ -463,7 +500,13 @@ async def main():
    if args.pdfs:
        await populate_pdf_work_queue(args)

    sglang_server = asyncio.create_task(sglang_server_task(args))
    # Create a semaphore to control worker access
    # We only allow one worker to move forward with requests, until the server has no more requests in its queue
    # This lets us get full utilization from many workers while still outputting dolma docs as soon as possible
    # As soon as one worker is no longer saturating the gpu, the next one can start sending requests
    semaphore = asyncio.Semaphore(1)

    sglang_server = asyncio.create_task(sglang_server_task(args, semaphore))

    work_queue = await load_pdf_work_queue(args)
    logger.info(f"Work queue prepared with {work_queue.qsize()} items")
@@ -473,7 +516,7 @@
    # Create worker tasks to process the queue concurrently.
    worker_tasks = []
    for i in range(args.workers):
        task = asyncio.create_task(worker(args, work_queue))
        task = asyncio.create_task(worker(args, work_queue, semaphore))
        worker_tasks.append(task)

    # Wait for the queue to be fully processed
@@ -501,4 +544,3 @@ async def main():
# TODO
# Possible future addon, in beaker, discover other nodes on this same job
# Send them a message when you take a work item off the queue
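
For reference, here is a minimal, self-contained sketch of the gating pattern this commit introduces: many workers share an asyncio.Semaphore(1), and a monitor coroutine releases it each time the server's request queue drains, so the next worker may begin submitting. The queue_depth coroutine below is a hypothetical stand-in for the real signal, which the commit obtains by parsing "#queue-req" out of the sglang server's log lines.

import asyncio
import random

async def queue_depth() -> int:
    # Hypothetical stand-in for parsing "#queue-req: N" from server logs
    await asyncio.sleep(0.05)
    return random.choice([0, 0, 3, 7])

async def monitor(semaphore: asyncio.Semaphore):
    last = None
    while True:
        depth = await queue_depth()
        # Release only on the non-zero -> zero transition, as process_line does above
        if last not in (None, 0) and depth == 0 and semaphore.locked():
            semaphore.release()
        last = depth

async def worker(name: str, semaphore: asyncio.Semaphore):
    # Block until the server queue has drained, then start submitting work
    await semaphore.acquire()
    print(f"{name} now submitting requests")
    await asyncio.sleep(0.3)  # stand-in for a batch of process_pdf calls

async def main():
    semaphore = asyncio.Semaphore(1)  # only one new worker may start at a time
    mon = asyncio.create_task(monitor(semaphore))
    await asyncio.gather(*(worker(f"w{i}", semaphore) for i in range(3)))
    mon.cancel()

asyncio.run(main())

Note that in this design the workers never release the semaphore themselves; only the monitor does, once it observes that the previous worker is no longer keeping the server's queue full.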
