
Commit

More pipeline code
jakep-allenai committed Oct 11, 2024
1 parent 10b7a58 commit 53fdb61
Showing 1 changed file: pdelfin/assemblepipeline.py (108 additions, 19 deletions)
@@ -4,9 +4,19 @@
import sqlite3
import json
import argparse
import glob
import tempfile
import posixpath

from pypdf import PdfReader
from tqdm import tqdm
from typing import Optional
from urllib.parse import urlparse
from concurrent.futures import ProcessPoolExecutor, as_completed

# Global s3 client for the whole script, feel free to adjust params if you need it
s3 = boto3.client('s3')

class DatabaseManager:
def __init__(self, s3_workspace: str):
cache_key = hashlib.sha256(s3_workspace.strip().lower().encode('utf-8')).hexdigest()
@@ -81,17 +91,35 @@ def update_processed_file(self, s3_path, etag):
""", (s3_path, etag))
self.conn.commit()

def pdf_exists(self, s3_path: str) -> bool:
self.cursor.execute("SELECT 1 FROM pdfs WHERE s3_path = ?", (s3_path,))
return self.cursor.fetchone() is not None

def add_pdf(self, s3_path: str, num_pages: int, status: str = 'pending') -> None:
try:
self.cursor.execute("""
INSERT INTO pdfs (s3_path, num_pages, status)
VALUES (?, ?, ?)
""", (s3_path, num_pages, status))
self.conn.commit()
except sqlite3.IntegrityError:
print(f"PDF with s3_path '{s3_path}' already exists.")

def get_pdf_status(self, s3_path: str) -> Optional[str]:
self.cursor.execute("SELECT status FROM pdfs WHERE s3_path = ?", (s3_path,))
result = self.cursor.fetchone()
return result[0] if result else None

def close(self):
self.conn.close()
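
To make the new helpers concrete, here is a minimal usage sketch of pdf_exists, add_pdf, and get_pdf_status; the workspace and PDF paths are placeholders, not paths from this repository:

db = DatabaseManager("s3://my-bucket/pdf-workspace")   # placeholder workspace
pdf_path = "s3://my-bucket/pdfs/example.pdf"           # placeholder PDF

if not db.pdf_exists(pdf_path):
    db.add_pdf(pdf_path, num_pages=12, status="pending")

print(db.get_pdf_status(pdf_path))  # -> "pending"
db.close()

Because add_pdf catches sqlite3.IntegrityError and prints a notice when a path is already present, re-running an import over the same PDFs is safe.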

def build_index(s3_path):
db_manager = DatabaseManager(s3_path)

    bucket, prefix = parse_s3_path(s3_path)

    # List every file matching the s3_path glob, along with its ETag
    files = expand_s3_glob(s3_path)

if not files:
print("No .json or .jsonl files found in the specified S3 path.")
Expand Down Expand Up @@ -131,18 +159,24 @@ def parse_s3_path(s3_path):
bucket, _, prefix = path.partition('/')
return bucket, prefix

def expand_s3_glob(s3_glob: str) -> dict[str, str]:
parsed = urlparse(s3_glob)
bucket_name = parsed.netloc
prefix = os.path.dirname(parsed.path.lstrip('/')).rstrip('/') + "/"
pattern = os.path.basename(parsed.path)


paginator = s3.get_paginator('list_objects_v2')
    page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)

    matched_files = {}
    for page in page_iterator:
        for obj in page.get('Contents', []):
            key = obj['Key']
            # Record the ETag of every key that matches the glob pattern
            if glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)):
                matched_files[f"s3://{bucket_name}/{key}"] = obj['ETag'].strip('"')

    return matched_files
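
As a sketch of the new helper's behavior (the bucket and prefix below are placeholders), expand_s3_glob takes an s3:// URL whose basename is a glob pattern and returns a dict mapping each matching object's full s3:// path to its ETag:

matched = expand_s3_glob("s3://my-bucket/workspace/*.jsonl")   # placeholder glob
for path, etag in matched.items():
    print(path, etag)

Note that only the basename is treated as a pattern: the directory portion is passed verbatim as the ListObjectsV2 prefix, so wildcards in intermediate path components are not expanded.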

def process_file(bucket, key, etag):
s3 = boto3.client('s3') # Initialize s3 client in the worker process
Expand Down Expand Up @@ -177,21 +211,76 @@ def process_jsonl_content(content, s3_path):
start_index = end_index
return index_entries

def get_s3_bytes(s3_path: str, start_index: Optional[int] = None, end_index: Optional[int] = None) -> bytes:
bucket, key = parse_s3_path(s3_path)

# Build the range header if start_index and/or end_index are specified
range_header = None
if start_index is not None or end_index is not None:
range_value = f"bytes={start_index or 0}-"
if end_index is not None:
range_value += str(end_index)
range_header = {'Range': range_value}

if range_header:
obj = s3.get_object(Bucket=bucket, Key=key, Range=range_header['Range'])
else:
obj = s3.get_object(Bucket=bucket, Key=key)

return obj['Body'].read()
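
A quick illustration of the optional byte-range support (the path is a placeholder); S3 Range requests are inclusive on both ends, so bytes=0-1023 returns 1024 bytes:

first_kb = get_s3_bytes("s3://my-bucket/workspace/output.jsonl", start_index=0, end_index=1023)
whole_obj = get_s3_bytes("s3://my-bucket/workspace/output.jsonl")   # no Range header sent

This pairs with the start_index/end_index offsets recorded by process_jsonl_content, presumably so that a single record can later be re-read without downloading an entire output file.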

def get_pdf_num_pages(s3_path: str) -> Optional[int]:
try:
with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
tf.write(get_s3_bytes(s3_path))
tf.flush()

reader = PdfReader(tf.name)
return reader.get_num_pages()
except Exception as ex:
print(f"Warning, could not add {s3_path} due to {ex}")

return None


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
    parser.add_argument('workspace', help='The S3 path where work will be done, e.g., s3://bucket/prefix/')
    parser.add_argument('--pdfs', help='Glob path to PDFs (local or s3)', default=None)
parser.add_argument('--file_size_limit', type=int, default=250, help='Max file size in MB')
args = parser.parse_args()

db = DatabaseManager(args.workspace)
print(f"Loaded db at {db.db_path}")
print(f"Current round is {db.get_current_round()}")
print(f"Current round is {db.get_current_round()}\n")

# One shared executor to rule them all
executor = ProcessPoolExecutor()

# If you have new PDFs, add them to the list
if args.pdfs:
assert args.pdfs.startswith("s3://"), "PDFs must live on s3"

print(f"Querying all PDFs at {args.pdfs}")

all_pdfs = expand_s3_glob(args.pdfs)
print(f"Found {len(all_pdfs)} total pdf paths")

all_pdfs = [pdf for pdf in all_pdfs if not db.pdf_exists(pdf)]
print(f"Need to import {len(all_pdfs)} total new pdf paths")

future_to_path = {executor.submit(get_pdf_num_pages, s3_path): s3_path for s3_path in all_pdfs}
for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
s3_path = future_to_path[future]
if future.result() and not db.pdf_exists(s3_path):
db.add_pdf(s3_path, future.result(), "pending")

print("\n")


    # Now build an index of all the pages that were processed within the workspace so far
    build_index(f"{args.workspace}/*.jsonl")

# Now, for each pending book, find all pages which still need to be processed
# and add them to the next round's batch inference jobs
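
That step is not implemented in this commit. Purely as a hedged sketch against the pdfs table shown above, where get_finished_pages stands in for a lookup against the page index that build_index maintains, it might look roughly like:

def queue_next_round(db: DatabaseManager):
    # Sketch only: walk pending PDFs and collect pages that have no result yet
    db.cursor.execute("SELECT s3_path, num_pages FROM pdfs WHERE status = 'pending'")
    for s3_path, num_pages in db.cursor.fetchall():
        finished = get_finished_pages(s3_path)  # hypothetical helper, not in this commit
        todo = [page for page in range(num_pages) if page not in finished]
        # ... emit `todo` as this PDF's entries in the next round's batch inference jobs ...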
