From 278422b8ff941e4fc1658242272424f6739808d5 Mon Sep 17 00:00:00 2001 From: Jake Poznanski Date: Fri, 15 Nov 2024 10:03:26 -0800 Subject: [PATCH] Fixing one max context issue --- pdelfin/beakerpipeline.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/pdelfin/beakerpipeline.py b/pdelfin/beakerpipeline.py index 0e44fc6..b176b22 100644 --- a/pdelfin/beakerpipeline.py +++ b/pdelfin/beakerpipeline.py @@ -245,6 +245,7 @@ async def load_pdf_work_queue(args) -> asyncio.Queue: # Determine remaining work remaining_work_hashes = set(work_queue) - done_work_hashes + #remaining_work_hashes = set(["0e779f21fbb75d38ed4242c7e5fe57fa9a636bac"]) remaining_work_queue = { hash_: work_queue[hash_] for hash_ in remaining_work_hashes @@ -280,6 +281,7 @@ async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf MAX_RETRIES = 3 exponential_backoffs = 0 + local_anchor_text_len = args.target_anchor_text_len attempt = 0 await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "started") @@ -288,7 +290,7 @@ async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf pdf_local_path, page_num, args.target_longest_image_dim, - args.target_anchor_text_len + local_anchor_text_len ) try: @@ -296,6 +298,11 @@ async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf response.raise_for_status() base_response_data = await response.json() + + if base_response_data["usage"]["total_tokens"] > args.model_max_context: + local_anchor_text_len = max(1, local_anchor_text_len // 2) + logger.info(f"Reducing anchor text len to {local_anchor_text_len} for {pdf_s3_path}-{page_num}") + raise ValueError(f"Response exceeded model_max_context, cannot use this response") metrics.add_metrics(sglang_input_tokens=base_response_data["usage"].get("prompt_tokens", 0), sglang_output_tokens=base_response_data["usage"].get("completion_tokens", 0)) @@ -328,6 +335,9 @@ async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf except json.JSONDecodeError as e: logger.warning(f"JSON decode error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}") attempt += 1 + except ValueError as e: + logger.warning(f"ValueError on attempt {attempt} for {pdf_s3_path}-{page_num}: {type(e)} - {e}") + attempt += 1 except Exception as e: logger.warning(f"Unexpected error on attempt {attempt} for {pdf_s3_path}-{page_num}: {type(e)} - {e}") attempt += 1 @@ -493,7 +503,10 @@ async def sglang_server_task(args, semaphore): "-m", "sglang.launch_server", "--model-path", model_cache_dir, "--chat-template", args.model_chat_template, - "--context-length", str(args.model_max_context), + + # TODO Had to comment this out, I thought it would be good to enforce a context limit on the server side, but it causes crashes + #"--context-length", str(args.model_max_context), + "--port", str(SGLANG_SERVER_PORT), "--log-level-http", "warning", stdout=asyncio.subprocess.PIPE,