Merge pull request #62 from allenai/amanr/bench
Added Gemini and Claude runners with a viewer.
Showing 10 changed files with 881 additions and 9 deletions.
@@ -0,0 +1,206 @@
import re
from dataclasses import dataclass
from typing import Optional


# This is the prompt we use for getting GPT-4o to convert documents into our silver training data
def build_openai_silver_data_prompt(base_text: str) -> str:
    return (
        f"Below is the image of one page of a PDF document, as well as some raw textual content that was previously extracted for it that includes position information for each image and block of text (The origin [0x0] of the coordinates is in the lower left corner of the image). "
        f"Just return the plain text representation of this document as if you were reading it naturally.\n"
        f"Turn equations into a LaTeX representation, and tables into markdown format. Remove the headers and footers, but keep references and footnotes.\n"
        f"Read any natural handwriting.\n"
        f"This is likely one page out of several in the document, so be sure to preserve any sentences that come from the previous page, or continue onto the next page, exactly as they are.\n"
        f"If there is no text at all that you think you should read, you can output null.\n"
        f"Do not hallucinate.\n"
        f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
    )
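# A minimal usage sketch (hypothetical path; get_anchor_text comes from olmocr and is not part of this file):
#
#   from olmocr.prompts.anchor import get_anchor_text
#   anchor = get_anchor_text("sample.pdf", 1, pdf_engine="pdfreport")
#   prompt = build_openai_silver_data_prompt(anchor)
#   # `prompt` is then sent to the model together with the rendered page image.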


@dataclass(frozen=True)
class PageResponse:
    primary_language: Optional[str]
    is_rotation_valid: bool
    rotation_correction: int
    is_table: bool
    is_diagram: bool
    natural_text: Optional[str]

    def __post_init__(self):
        # Validate rotation_correction is one of the allowed values
        if self.rotation_correction not in {0, 90, 180, 270}:
            raise ValueError("rotation_correction must be one of [0, 90, 180, 270].")

        # Type checks
        if not isinstance(self.primary_language, (str, type(None))):
            raise TypeError("primary_language must be of type Optional[str].")
        if not isinstance(self.is_rotation_valid, bool):
            raise TypeError("is_rotation_valid must be of type bool.")
        if not isinstance(self.rotation_correction, int):
            raise TypeError("rotation_correction must be of type int.")
        if not isinstance(self.is_table, bool):
            raise TypeError("is_table must be of type bool.")
        if not isinstance(self.is_diagram, bool):
            raise TypeError("is_diagram must be of type bool.")
        if not isinstance(self.natural_text, (str, type(None))):
            raise TypeError("natural_text must be of type Optional[str].")


def openai_response_format_schema() -> dict:
    return {
        "type": "json_schema",
        "json_schema": {
            "name": "page_response",
            "schema": {
                "type": "object",
                "properties": {
                    "primary_language": {
                        "type": ["string", "null"],
                        "description": "The primary language of the text using two-letter codes or null if there is no text at all that you think you should read.",
                    },
                    "is_rotation_valid": {
                        "type": "boolean",
                        "description": "Is this page oriented correctly for reading? Answer only considering the textual content, do not factor in the rotation of any charts, tables, drawings, or figures.",
                    },
                    "rotation_correction": {
                        "type": "integer",
                        "description": "Indicates the degree of clockwise rotation needed if the page is not oriented correctly.",
                        "enum": [0, 90, 180, 270],
                        "default": 0,
                    },
                    "is_table": {
                        "type": "boolean",
                        "description": "Indicates if the majority of the page content is in tabular format.",
                    },
                    "is_diagram": {
                        "type": "boolean",
                        "description": "Indicates if the majority of the page content is a visual diagram.",
                    },
                    "natural_text": {
                        "type": ["string", "null"],
                        "description": "The natural text content extracted from the page.",
                    },
                },
                "additionalProperties": False,
                "required": [
                    "primary_language",
                    "is_rotation_valid",
                    "rotation_correction",
                    "is_table",
                    "is_diagram",
                    "natural_text",
                ],
            },
            "strict": True,
        },
    }
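# Sketch of the intended call shape (not part of this diff); the dict above matches the
# `response_format` parameter of OpenAI's structured-output chat completions API:
#
#   import json
#   from openai import OpenAI
#
#   client = OpenAI()
#   completion = client.chat.completions.create(
#       model="gpt-4o-2024-08-06",  # assumed model name
#       messages=[{"role": "user", "content": build_openai_silver_data_prompt(anchor_text)}],  # page image omitted here
#       response_format=openai_response_format_schema(),
#   )
#   page = PageResponse(**json.loads(completion.choices[0].message.content))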


# Returned as a one-element list so the result can be passed directly to the Anthropic `tools=` parameter
def claude_response_format_schema() -> list:
    return [
        {
            "name": "page_response",
            "description": "Extracts text from PDFs.",
            "input_schema": {
                "type": "object",
                "properties": {
                    "primary_language": {
                        "type": ["string", "null"],
                        "description": "The primary language of the text using two-letter codes or null if there is no text at all that you think you should read.",
                    },
                    "is_rotation_valid": {
                        "type": "boolean",
                        "description": "Is this page oriented correctly for reading? Answer only considering the textual content, do not factor in the rotation of any charts, tables, drawings, or figures.",
                    },
                    "rotation_correction": {
                        "type": "integer",
                        "description": "Indicates the degree of clockwise rotation needed if the page is not oriented correctly.",
                        "enum": [0, 90, 180, 270],
                        "default": 0,
                    },
                    "is_table": {
                        "type": "boolean",
                        "description": "Indicates if the majority of the page content is in tabular format.",
                    },
                    "is_diagram": {
                        "type": "boolean",
                        "description": "Indicates if the majority of the page content is a visual diagram.",
                    },
                    "natural_text": {
                        "type": ["string", "null"],
                        "description": "The natural text content extracted from the page.",
                    },
                },
                "required": [
                    "primary_language",
                    "is_rotation_valid",
                    "rotation_correction",
                    "is_table",
                    "is_diagram",
                    "natural_text",
                ],
            },
        }
    ]


def gemini_response_format_schema() -> dict:
    return {
        "type": "OBJECT",
        "properties": {
            "primary_language": {
                "type": "STRING",
                "description": "The primary language of the text using two-letter codes or null if there is no text at all that you think you should read.",
            },
            "is_rotation_valid": {
                "type": "BOOL",
                "description": "Is this page oriented correctly for reading? Answer only considering the textual content, do not factor in the rotation of any charts, tables, drawings, or figures.",
            },
            "rotation_correction": {
                "type": "INTEGER",
                "enum": [0, 90, 180, 270],
                "description": "Indicates the degree of clockwise rotation needed if the page is not oriented correctly.",
            },
            "is_table": {"type": "BOOL", "description": "Indicates if the majority of the page content is in tabular format."},
            "is_diagram": {"type": "BOOL", "description": "Indicates if the majority of the page content is a visual diagram."},
            "natural_text": {"type": "STRING", "description": "The natural text content extracted from the page."},
        },
        "required": ["primary_language", "is_rotation_valid", "rotation_correction", "is_table", "is_diagram", "natural_text"],
        "propertyOrdering": ["primary_language", "is_rotation_valid", "rotation_correction", "is_table", "is_diagram", "natural_text"],
    }


def build_find_difference_prompt(base_text: str) -> str:
    return (
        f"Below is an image of a document page, along with raw textual content previously extracted using different models.\n"
        f"Your goal is to carefully identify the differences between the extracted texts from both models and determine which one is more accurate by comparing them with the image.\n"
        f"Only return the differences and specify which model extracted the text with higher accuracy.\n"
        f"Do not hallucinate.\n"
        f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
    )


# This is a base prompt that will be used for training and running the fine-tuned model.
# It's simplified from the prompt which was used to generate the silver data, and can change from dataset to dataset.
def build_finetuning_prompt(base_text: str) -> str:
    return (
        f"Below is the image of one page of a document, as well as some raw textual content that was previously extracted for it. "
        f"Just return the plain text representation of this document as if you were reading it naturally.\n"
        f"Do not hallucinate.\n"
        f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
    )


# Extracts the anchor text component from an existing prompt string
def extract_raw_text(prompt: str) -> str:
    pattern = r"RAW_TEXT_START\s*\n(.*?)\nRAW_TEXT_END"

    # Use re.DOTALL so that the dot matches newline characters
    match = re.search(pattern, prompt, re.DOTALL)

    if match:
        return match.group(1).strip()
    else:
        raise ValueError("Prompt does not contain raw text")
Empty file.
@@ -0,0 +1,61 @@
import json
import os

from anthropic import Anthropic

from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts.anchor import get_anchor_text
from prompts import (
    build_openai_silver_data_prompt,
    claude_response_format_schema,
)


def run_claude(pdf_path: str, page_num: int = 1, model: str = "claude-3-7-sonnet-20250219", temperature: float = 0.1) -> str:
    """
    Convert a page of a PDF file to structured text using Claude OCR.

    This function renders the specified page of the PDF to an image, runs OCR on that image,
    and returns the result as a JSON string that follows the page_response schema.

    Args:
        pdf_path (str): The local path to the PDF file.
        page_num (int): The page number to process (starting from 1).
        model (str): The Claude model to use.
        temperature (float): The temperature parameter for generation.

    Returns:
        str: The OCR result as a JSON string.
    """
    image_base64 = render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=2048)
    anchor_text = get_anchor_text(pdf_path, page_num, pdf_engine="pdfreport")
    client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
    response = client.messages.create(
        model=model,
        max_tokens=3000,
        temperature=temperature,
        # system=system_prompt,
        tools=claude_response_format_schema(),
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": image_base64}},
                    {
                        "type": "text",
                        "text": f"{build_openai_silver_data_prompt(anchor_text)}. Use the page_response tool to respond. If the properties are true, then extract the text from them and respond in natural_text.",
                    },
                ],
            }
        ],
    )

    # Pull the structured output of the page_response tool call out of the response
    tool_output = None
    for content in response.content:
        if content.type == "tool_use" and content.name == "page_response":
            tool_output = content.input
            break

    if tool_output is None:
        raise ValueError("Claude did not return a page_response tool call")

    return json.dumps(tool_output, indent=2)
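A rough usage sketch for this runner, assuming ANTHROPIC_API_KEY is set and sample.pdf is a placeholder path; the returned JSON string can be validated with the PageResponse dataclass from the sibling prompts module:

import json

from prompts import PageResponse

raw = run_claude("sample.pdf", page_num=1)
page = PageResponse(**json.loads(raw))
if page.natural_text is not None:
    print(page.natural_text)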
@@ -0,0 +1,78 @@
import base64
import os

from google.ai import generativelanguage as glm
from google.api_core import client_options

from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts.anchor import get_anchor_text
from prompts import (  # gemini_response_format_schema,
    build_openai_silver_data_prompt,
)


def run_gemini(pdf_path: str, page_num: int = 1, model: str = "gemini-1.5-pro", temperature: float = 0.1) -> str:
    """
    Convert a page of a PDF file to markdown using Gemini's vision capabilities.

    This function renders the specified page of the PDF to an image, runs OCR on that image,
    and returns the OCR result as a markdown-formatted string.

    Args:
        pdf_path (str): The local path to the PDF file.
        page_num (int): The page number to process (starting from 1).
        model (str): The Gemini model to use.
        temperature (float): The temperature parameter for generation.

    Returns:
        str: The OCR result in markdown format.
    """
    image_base64 = render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=2048)
    anchor_text = get_anchor_text(pdf_path, page_num, pdf_engine="pdfreport")
    api_key = os.getenv("GEMINI_API_KEY")
    client = glm.GenerativeServiceClient(
        client_options=client_options.ClientOptions(
            api_key=api_key,
        ),
    )

    image_part = glm.Part(inline_data=glm.Blob(mime_type="image/png", data=base64.b64decode(image_base64)))

    text_part = glm.Part(text=f"""{build_openai_silver_data_prompt(anchor_text)}""")
    generation_config = glm.GenerationConfig(
        temperature=temperature,
        top_p=1.0,
        top_k=32,
        max_output_tokens=4096,
    )
    # response_schema = gemini_response_format_schema()
    request = glm.GenerateContentRequest(
        model=f"models/{model}",
        contents=[glm.Content(parts=[image_part, text_part])],
        generation_config=generation_config,
    )

    # Alternative request that forces a page_response function call (left disabled in this PR):
    # request = glm.GenerateContentRequest(
    #     model=f"models/{model}",
    #     contents=[glm.Content(parts=[image_part, text_part])],
    #     generation_config=generation_config,
    #     tools=[
    #         glm.Tool(
    #             function_declarations=[
    #                 glm.FunctionDeclaration(
    #                     name="page_response",
    #                     parameters=response_schema,
    #                 )
    #             ]
    #         )
    #     ],
    #     tool_config=glm.ToolConfig(
    #         function_calling_config=glm.FunctionCallingConfig(
    #             mode="any",
    #             allowed_function_names=["page_response"],
    #         )
    #     ),
    # )

    response = client.generate_content(request)
    result = response.candidates[0].content.parts[0].text
    return result
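The two runners can feed build_find_difference_prompt from the prompts module to compare extractions against the page image. The sketch below assumes both runner functions are in scope (their module filenames are not shown in this excerpt) and that how the two outputs are combined is up to the caller:

from prompts import build_find_difference_prompt

claude_text = run_claude("sample.pdf", page_num=1)
gemini_text = run_gemini("sample.pdf", page_num=1)

combined = f"MODEL_A (Claude):\n{claude_text}\n\nMODEL_B (Gemini):\n{gemini_text}"
comparison_prompt = build_find_difference_prompt(combined)
# comparison_prompt would then be sent, with the page image, to a judge model.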