Initializing the tokenizer and the embedding model

import duckdb
import torch
from lindera_py import Segmenter, Tokenizer, load_dictionary
from sentence_transformers import CrossEncoder
from transformers import AutoModel, AutoTokenizer
 
device = "cuda" if torch.cuda.is_available() else "cpu"
 
# Embedding model: pfnet/plamo-embedding-1b (2048-dimensional output vectors)
v_tokenizer = AutoTokenizer.from_pretrained(
    "pfnet/plamo-embedding-1b", trust_remote_code=True
)
v_model = AutoModel.from_pretrained("pfnet/plamo-embedding-1b", trust_remote_code=True)
v_model = v_model.to(device)
 
# Japanese morphological analyzer (lindera + ipadic) used for wakati-gaki segmentation
dictionary = load_dictionary("ipadic")
segmenter = Segmenter("normal", dictionary)
tokenizer = Tokenizer(segmenter)
 
# Cross-encoder reranker used later for hybrid search
r_model = CrossEncoder(
    "hotchpotch/japanese-bge-reranker-v2-m3-v1", max_length=512, device=device
)
 
def ja_tokens(text: str) -> str:
    # Join lindera tokens with spaces so DuckDB FTS can index and query them
    return " ".join(t.text for t in tokenizer.tokenize(text))
 

Converting the JSON files fetched from Readwise into a DataFrame

  • Steps performed here:
    • Take the JSON data fetched¹ from the Readwise Reader API, convert its HTML to Markdown, and build a DataFrame (a sketch of the assumed record shape follows below)
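The loader below only relies on three fields per record: id, title, and html_content; any other fields in the export are dropped later when the DataFrame is trimmed to the needed columns. A minimal sketch of the assumed record shape, with placeholder values:

# Assumed minimal shape of one exported record (illustrative, not the full Reader API schema)
example_record = {
    "id": "doc_123",                                  # document id (placeholder)
    "title": "記事タイトル",                           # article title
    "html_content": "<h1>見出し</h1><p>本文 ...</p>",  # HTML body converted to Markdown below
}
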
import glob
from markdownify import markdownify
import json
import pandas as pd
 
def to_markdown(html: str) -> str:
    return markdownify(html or "", heading_style="ATX")
 
json_files = glob.glob("articles/*.json")
records = []
for path in json_files:
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    for rec in (data if isinstance(data, list) else [data]):
        md = to_markdown(rec.get("html_content", ""))
        title = rec.get("title", "")
        rec["fts_text"] = ja_tokens(title + "。 " + md)
        rec["vss_text"] = title + "。 " + md
        rec["markdown_content"] = md
        records.append(rec)
 
df = pd.DataFrame(records)
 
# Keep only the columns we need
cols = ["id", "title", "markdown_content", "fts_text", "vss_text"]
df = df[cols]
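
A quick, purely illustrative look at the result:

# Illustrative: inspect the trimmed DataFrame
print(len(df), "articles loaded")
display(df.head(3))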

Storing the data in DuckDB

  • Steps performed here:
    • Compute the embedding (vector) for each article and store it
    • Build the FTS (full-text search) index

con = duckdb.connect("article_search.duckdb")
con.install_extension("vss")
con.load_extension("vss")
con.install_extension("fts")
con.load_extension("fts")
 
 
# 1. Create the empty table first
con.sql("""
CREATE TABLE IF NOT EXISTS articles (
    id VARCHAR PRIMARY KEY,
    title VARCHAR,
    markdown_content VARCHAR,
    fts_text VARCHAR,          -- pre-segmented (wakati-gaki) full text
    embedding FLOAT[2048]      -- plamo-embedding-1b: 2048-dimensional vectors
);
""")
 
# If you want to wipe and re-insert the existing data
con.sql("DELETE FROM articles;")
 
# 2. Compute embeddings and insert rows
for i, rec in df.iterrows():
    # Skip if this id is already in the DB
    already = con.sql(
        "SELECT COUNT(*) FROM articles WHERE id = ?", params=[rec["id"]]
    ).fetchone()[0]
    if already:
        print(f"SKIP: {rec['id']} (already in DB)")
        continue
 
    try:
        with torch.inference_mode():
            emb = v_model.encode_document([rec["vss_text"]], v_tokenizer)[0]
        con.execute(
            "INSERT INTO articles VALUES (?, ?, ?, ?, ?)",
            [
                rec["id"],
                rec["title"],
                rec["markdown_content"],
                rec["fts_text"],
                emb.cpu().squeeze().numpy().tolist(),
            ],
        )
        print(f"INSERT: {rec['id']}")
    except Exception as e:
        print(f"ERROR: {rec['id']} - {e}")
 
# 3. FTS index
# Drop any existing index first
try:
    con.sql("PRAGMA drop_fts_index('articles');")
except Exception as e:
    print("FTSインデックス削除時の例外:", e)
    
con.sql("""
PRAGMA create_fts_index(
    'articles',     -- table name
    'id',           -- document id column (primary key)
    'fts_text',
    stemmer='none', stopwords='none', ignore='', lower=false, strip_accents=false
);
""")

Defining the query functions

def fts_search(conn, query, k=5):
    # BM25 full-text search; segment the query the same way as fts_text
    q = ja_tokens(query)
    return conn.sql(f"""
        SELECT id, title,
               fts_main_articles.match_bm25(id, '{q}') AS score
        FROM articles
        WHERE score IS NOT NULL
        ORDER BY score DESC
        LIMIT {k}
    """).fetchdf()
 
def vss_search(conn, query, k=5):
    # Vector search: embed the query and rank by cosine distance
    with torch.inference_mode():
        q_emb = v_model.encode_query(query, v_tokenizer)
    return conn.sql(
        f"""
        SELECT id, title,
               array_cosine_distance(embedding, ?::FLOAT[2048]) AS dist
        FROM articles
        ORDER BY dist ASC
        LIMIT {k}
        """,
        params=[q_emb.cpu().squeeze().numpy().tolist()],
    ).fetchdf()
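
vss_search above does a brute-force scan with array_cosine_distance, which is fine at this scale. Since the vss extension is already loaded, an HNSW index could optionally speed up the ORDER BY ... LIMIT query; the sketch below follows the extension's documented syntax (HNSW persistence in a file-backed database is still flagged experimental).

# Optional: HNSW index so the cosine-distance top-k query can use an index scan
con.sql("SET hnsw_enable_experimental_persistence = true;")
con.sql("CREATE INDEX IF NOT EXISTS idx_articles_embedding ON articles USING HNSW (embedding) WITH (metric = 'cosine');")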
 
 
def hybrid_search(conn, query, k_fts=10, k_vss=10, top_n=5):
    # Union of FTS and vector candidates, reranked with the cross-encoder (r_model)
    fts_df = fts_search(conn, query, k_fts)
    vss_df = vss_search(conn, query, k_vss)
    pool = pd.concat([fts_df, vss_df]).drop_duplicates("id")
    pairs = [(query, txt) for txt in pool["title"]]
    scores = r_model.predict(pairs)
    pool["score"] = scores
    return pool.sort_values("score", ascending=False).head(top_n)[["id", "title", "score"]]

Testing the queries

print("全文検索 ↓")
display(fts_search(con, "型システム", k=5))
 
print("ベクトル検索 ↓")
display(vss_search(con, "型システム", k=5))
 
print("ハイブリッド検索 ↓")
display(hybrid_search(con, "型システム", top_n=5))

Footnotes

  1. The script is available here