Commit: Formatting fix

jakep-allenai committed Feb 14, 2025
1 parent 0dcdbcc · commit 32aa359
Showing 3 changed files with 73 additions and 69 deletions.
56 changes: 26 additions & 30 deletions olmocr/train/hf/convertjsontoparquet.py
@@ -8,23 +8,24 @@
# The url will be the result of get_uri_from_db
# Response will be NormalizedEntry.text
import argparse
import concurrent.futures
import glob
import json
import multiprocessing
import os
import re
import shutil
import sqlite3
import tempfile
import os
import shutil
from dataclasses import dataclass
from typing import Optional, List, Tuple, Dict, Set
import concurrent.futures
from typing import Dict, List, Optional, Set, Tuple
from urllib.parse import urlparse

import boto3
from tqdm import tqdm
import pandas as pd
from pypdf import PdfReader, PdfWriter
from tqdm import tqdm


def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
"""
@@ -44,7 +45,7 @@ def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
return urlparse(pretty_pdf_path).path.split("/")[-1]
else:
raise NotImplementedError()


def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
"""
@@ -58,6 +59,7 @@ def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
conn.close()
return result[0].strip() if result and result[0] else None


@dataclass(frozen=True)
class NormalizedEntry:
s3_path: str
@@ -70,7 +72,7 @@ class NormalizedEntry:
def from_goldkey(goldkey: str, **kwargs):
"""
Constructs a NormalizedEntry from a goldkey string.
The goldkey is expected to be of the format:
The goldkey is expected to be of the format:
<s3_path>-<page_number>
"""
s3_path = goldkey[: goldkey.rindex("-")]
@@ -81,6 +83,7 @@ def from_goldkey(goldkey: str, **kwargs):
def goldkey(self):
return f"{self.s3_path}-{self.pagenum}"


def normalize_json_entry(data: dict) -> NormalizedEntry:
"""
Normalizes a JSON entry from any of the supported formats.
@@ -117,6 +120,7 @@ def normalize_json_entry(data: dict) -> NormalizedEntry:
else:
raise ValueError("Unsupported JSON format")


def parse_s3_url(s3_url: str) -> Tuple[str, str]:
"""
Parses an S3 URL of the form s3://bucket/key and returns (bucket, key).
@@ -127,6 +131,7 @@ def parse_s3_url(s3_url: str) -> Tuple[str, str]:
bucket, key = s3_path.split("/", 1)
return bucket, key


def download_pdf_to_cache(s3_url: str, cache_dir: str) -> Optional[str]:
"""
Downloads the PDF from the given S3 URL into the specified cache directory.
@@ -135,11 +140,11 @@ def download_pdf_to_cache(s3_url: str, cache_dir: str) -> Optional[str]:
"""
try:
bucket, key = parse_s3_url(s3_url)
s3_client = boto3.client('s3')
s3_client = boto3.client("s3")
pdf_hash = parse_pdf_hash(s3_url)
if not pdf_hash:
# Fallback: use a sanitized version of the s3_url
pdf_hash = re.sub(r'\W+', '_', s3_url)
pdf_hash = re.sub(r"\W+", "_", s3_url)
dest_path = os.path.join(cache_dir, f"{pdf_hash}.pdf")
# Avoid re-downloading if already exists
if not os.path.exists(dest_path):
@@ -149,6 +154,7 @@ def download_pdf_to_cache(s3_url: str, cache_dir: str) -> Optional[str]:
print(f"Error downloading {s3_url}: {e}")
return None


def process_pdf_page(s3_url: str, page_number: int, combined_id: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Optional[str]:
"""
Extracts the specified page (1-indexed) from the cached PDF corresponding to s3_url.
@@ -178,6 +184,7 @@ def process_pdf_page(s3_url: str, page_number: int, combined_id: str, output_pdf
print(f"Error processing PDF page for {s3_url} page {page_number}: {e}")
return None


def process_file(file_path: str, db_path: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Tuple[List[dict], int]:
"""
Process a single file and return a tuple:
@@ -215,8 +222,7 @@ def process_file(file_path: str, db_path: str, output_pdf_dir: str, pdf_cache: D

# Apply filter: skip if response contains "resume" (any case) and an email or phone number.
response_text = normalized.text if normalized.text else ""
if (re.search(r"resume", response_text, re.IGNORECASE) and
(re.search(email_regex, response_text) or re.search(phone_regex, response_text))):
if re.search(r"resume", response_text, re.IGNORECASE) and (re.search(email_regex, response_text) or re.search(phone_regex, response_text)):
print(f"Skipping entry due to resume and contact info in response at {file_path}:{line_num}")
continue

@@ -254,6 +260,7 @@ def process_file(file_path: str, db_path: str, output_pdf_dir: str, pdf_cache: D
print(f"Error processing file {file_path}: {e}")
return rows, missing_count


def scan_file_for_s3_urls(file_path: str) -> Set[str]:
"""
Scans a single file and returns a set of unique S3 URLs found in the JSON entries.
@@ -276,18 +283,15 @@ def scan_file_for_s3_urls(file_path: str) -> Set[str]:
print(f"Error reading file {file_path}: {e}")
return urls


def main():
parser = argparse.ArgumentParser(
description="Generate a Parquet dataset file for HuggingFace upload."
)
parser = argparse.ArgumentParser(description="Generate a Parquet dataset file for HuggingFace upload.")
parser.add_argument(
"input_dataset",
help="Input dataset file pattern (e.g., '/data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json')",
)
parser.add_argument("db_path", help="Path to the SQLite database file.")
parser.add_argument(
"--output", default="output.parquet", help="Output Parquet file path."
)
parser.add_argument("--output", default="output.parquet", help="Output Parquet file path.")

args = parser.parse_args()

@@ -303,7 +307,7 @@ def main():
# Create a temporary directory for caching PDFs.
pdf_cache_dir = "/tmp/pdf_cache"
os.makedirs(pdf_cache_dir, exist_ok=True)

print(f"Caching PDFs to temporary directory: {pdf_cache_dir}")

# ---------------------------------------------------------------------
@@ -323,12 +327,8 @@ def main():
pdf_cache: Dict[str, str] = {}
print("Caching PDFs from S3...")
with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpus * 8) as executor:
future_to_url = {
executor.submit(download_pdf_to_cache, s3_url, pdf_cache_dir): s3_url
for s3_url in unique_s3_urls
}
for future in tqdm(concurrent.futures.as_completed(future_to_url),
total=len(future_to_url), desc="Downloading PDFs"):
future_to_url = {executor.submit(download_pdf_to_cache, s3_url, pdf_cache_dir): s3_url for s3_url in unique_s3_urls}
for future in tqdm(concurrent.futures.as_completed(future_to_url), total=len(future_to_url), desc="Downloading PDFs"):
s3_url = future_to_url[future]
try:
local_path = future.result()
@@ -345,12 +345,8 @@ def main():
total_missing = 0
print("Processing files...")
with concurrent.futures.ProcessPoolExecutor() as executor:
futures = {
executor.submit(process_file, file_path, args.db_path, pdfs_dir, pdf_cache): file_path
for file_path in files
}
for future in tqdm(concurrent.futures.as_completed(futures),
total=len(futures), desc="Processing files"):
futures = {executor.submit(process_file, file_path, args.db_path, pdfs_dir, pdf_cache): file_path for file_path in files}
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing files"):
file_path = futures[future]
try:
rows, missing_count = future.result()
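For context on the code these hunks reformat: NormalizedEntry.from_goldkey splits a goldkey of the form <s3_path>-<page_number> on its last hyphen, and process_pdf_page pulls a single page out of a cached PDF with pypdf. Below is a minimal sketch of both steps, assuming the standard pypdf PdfReader/PdfWriter API; the helper names, sample paths, and goldkey value are illustrative and not part of this commit.

from typing import Tuple

from pypdf import PdfReader, PdfWriter


def split_goldkey(goldkey: str) -> Tuple[str, int]:
    # Same split as NormalizedEntry.from_goldkey: everything before the last "-" is the s3_path.
    s3_path = goldkey[: goldkey.rindex("-")]
    page_num = int(goldkey[goldkey.rindex("-") + 1 :])
    return s3_path, page_num


def extract_page(src_pdf: str, page_number: int, dest_pdf: str) -> None:
    # Write a single 1-indexed page of src_pdf to dest_pdf (pypdf pages are 0-indexed).
    reader = PdfReader(src_pdf)
    writer = PdfWriter()
    writer.add_page(reader.pages[page_number - 1])
    with open(dest_pdf, "wb") as f:
        writer.write(f)


if __name__ == "__main__":
    path, page = split_goldkey("s3://bucket/documents/abc123.pdf-4")
    print(path, page)  # -> s3://bucket/documents/abc123.pdf 4
    # extract_page("/tmp/pdf_cache/abc123.pdf", page, "/tmp/abc123_page4.pdf")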
30 changes: 15 additions & 15 deletions olmocr/train/hf/hfhub_upload.py
@@ -1,23 +1,21 @@
import logging
import os
import tarfile
import logging
from math import ceil
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
from math import ceil

from huggingface_hub import HfApi
from tqdm import tqdm

# Configuration
pdf_dir = "pdfs" # Directory with PDF files (flat structure)
tarball_dir = "tarballs" # Directory where tar.gz files will be saved
pdf_dir = "pdfs" # Directory with PDF files (flat structure)
tarball_dir = "tarballs" # Directory where tar.gz files will be saved
os.makedirs(tarball_dir, exist_ok=True)
repo_id = "allenai/olmOCR-mix-0225" # Hugging Face dataset repo ID

# Set up logging to file
logging.basicConfig(
filename='upload.log',
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
logging.basicConfig(filename="upload.log", level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")


def process_chunk(args):
"""
@@ -27,7 +25,7 @@ def process_chunk(args):
chunk_index, chunk_files = args
tarball_name = f"pdf_chunk_{chunk_index:04d}.tar.gz"
tarball_path = os.path.join(tarball_dir, tarball_name)

try:
with tarfile.open(tarball_path, "w:gz") as tar:
for pdf_filename in chunk_files:
@@ -41,10 +39,11 @@ def process_chunk(args):
logging.error(error_msg)
return chunk_index, False, error_msg


def main():
# List all PDF files (assuming a flat directory)
try:
pdf_files = sorted([f for f in os.listdir(pdf_dir) if f.lower().endswith('.pdf')])
pdf_files = sorted([f for f in os.listdir(pdf_dir) if f.lower().endswith(".pdf")])
except Exception as e:
logging.error(f"Error listing PDFs in '{pdf_dir}': {e}")
return
@@ -61,7 +60,7 @@ def main():
# end = start + chunk_size
# chunk_files = pdf_files[start:end]
# chunks.append((idx, chunk_files))

# # Create tarballs in parallel
# results = []
# with ProcessPoolExecutor() as executor:
@@ -90,10 +89,11 @@ def main():
api.upload_large_folder(
folder_path=tarball_dir,
repo_id=repo_id,
#path_in_repo="pdf_tarballs",
repo_type="dataset"
# path_in_repo="pdf_tarballs",
repo_type="dataset",
)
logging.info("Successfully uploaded tarballs folder to Hugging Face Hub.")


if __name__ == "__main__":
main()
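The upload script above tars the flat pdfs/ directory into numbered chunks (pdf_chunk_0000.tar.gz, ...) and then pushes the tarball directory with HfApi.upload_large_folder. Below is a minimal sketch of the chunking step that feeds process_chunk-style workers; the chunk_size value and the make_chunks and count_chunk helpers are assumptions for illustration, not taken from this commit.

import os
from concurrent.futures import ProcessPoolExecutor
from math import ceil

pdf_dir = "pdfs"
chunk_size = 1000  # assumed; the real chunk size is not shown in this diff


def make_chunks(pdf_files, size):
    # Partition the sorted PDF list into (chunk_index, chunk_files) pairs,
    # the same argument shape that process_chunk unpacks above.
    num_chunks = ceil(len(pdf_files) / size)
    return [(idx, pdf_files[idx * size : (idx + 1) * size]) for idx in range(num_chunks)]


def count_chunk(args):
    chunk_index, chunk_files = args
    return chunk_index, len(chunk_files)


if __name__ == "__main__":
    pdf_files = sorted(f for f in os.listdir(pdf_dir) if f.lower().endswith(".pdf"))
    with ProcessPoolExecutor() as executor:
        for chunk_index, n in executor.map(count_chunk, make_chunks(pdf_files, chunk_size)):
            print(f"pdf_chunk_{chunk_index:04d}.tar.gz would contain {n} PDFs")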
