# chess / app.py
# Captured from the Hugging Face file viewer: commit a8850d7 ("Update app.py")
# by christopher, marked verified; file size 2.63 kB. (Viewer links: raw, history, blame.)
import pickle
import time

import chess
from datasets import load_dataset
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pyroaring import BitMap
def board_to_tokens(board):
    """Describe every occupied square of *board* as a (piece, square) token.

    Squares are visited in ``chess.SQUARES`` order. Each token pairs the
    piece's symbol (e.g. ``'K'`` or ``'p'``) with the square's name
    (e.g. ``'e1'``); empty squares contribute nothing.
    """
    tokens = []
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece:
            tokens.append((piece.symbol(), chess.square_name(square)))
    return tokens
def get_puzzle_positions(fen, moves_uci):
    """Replay a puzzle's move list and snapshot the board after every move.

    Args:
        fen: Starting position in FEN.
        moves_uci: Space-separated UCI moves to play from *fen*.

    Returns:
        A list of independent board copies; element ``i`` is the position
        after the first ``i + 1`` moves have been played. An empty (or
        all-whitespace) move string yields an empty list instead of the
        ``IndexError`` the previous version raised.

    Raises:
        ValueError: (or a python-chess subclass of it) if *fen* cannot be
            parsed or a move is invalid/illegal in the current position.
    """
    board = chess.Board(fen)
    positions = []
    for move_uci in moves_uci.split():
        board.push_uci(move_uci)
        # Copy: `board` keeps mutating as later moves are pushed.
        positions.append(board.copy())
    return positions
def load_index(path='chess_index.pkl'):
    """Load the pickled search index stored at *path*.

    The pickle must contain a mapping with ``'index'`` and ``'metadata'``
    entries; both are returned, in that order. Other entries are ignored.
    A relative *path* is resolved against the current working directory.

    Security: ``pickle.load`` can execute arbitrary code, so *path* must
    point to a trusted file (for example, one generated by this project).
    """
    with open(path, 'rb') as handle:
        payload = pickle.load(handle)
    idx, meta = payload['index'], payload['metadata']
    return idx, meta
def query_positions(index, metadata, query_tokens):
    """Find every indexed position that contains all of *query_tokens*.

    Args:
        index: Mapping from token to a bitmap-like set of position ids
            (must support ``len``, ``copy``, ``&=`` and truthiness, e.g. a
            pyroaring ``BitMap``). It is never mutated.
        metadata: Lookup from position id to that position's metadata.
        query_tokens: Tokens that must all be present, e.g. the output of
            ``board_to_tokens``.

    Returns:
        A list of ``(pos_id, metadata[pos_id])`` pairs, in the iteration
        order of the intersected id set (ascending for a ``BitMap``).
        The result is always a list. It is empty when *query_tokens* is empty
        (there is no constraint to intersect) or any token is absent from
        *index*.
    """
    if not query_tokens:
        return []
    postings = []
    for token in query_tokens:
        if token not in index:
            # No position has this token, so none can contain all of them.
            return []
        postings.append(index[token])
    # Intersect smallest-first: the running result can only shrink, so
    # starting small keeps each step cheap and allows an early exit.
    postings.sort(key=len)
    result = postings[0].copy()  # copy so the shared index is never mutated
    for posting in postings[1:]:
        result &= posting
        if not result:
            return []
    return [(pos_id, metadata[pos_id]) for pos_id in result]
# Module-level setup: everything below runs once, when this module is imported.
# Puzzle rows; `search` reads them by integer position (dset[puzzle_row]), so this
# must be the same dataset/split the index was built from.
# NOTE(review): confirm against whatever script produced chess_index.pkl.
dset = load_dataset("Lichess/chess-puzzles", split="train")
# `index`: (piece symbol, square name) token -> bitmap of position ids (see query_positions).
# `metadata`: position id -> (puzzle_row, move_idx), as unpacked in `search`.
# Loaded via pickle from a path relative to the working directory: only use a trusted file.
index, metadata = load_index()
app = FastAPI()
# Serve files from the "static" directory under /static; render pages from "templates".
# Both directory names are relative paths, presumably resolved against the working directory.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
@app.get("/")
def read_root(request: Request):
    """Serve the search page by rendering the ``index.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
def _parse_query_board(data):
    """Build a board from the request body's ``'fen'`` field, or raise HTTP 422."""
    fen = data.get('fen')
    # chess.Board(None) would silently build an empty board, so require a string.
    if not isinstance(fen, str):
        raise HTTPException(status_code=422, detail="Request body must include a 'fen' string.")
    try:
        return chess.Board(fen)
    except ValueError as err:
        raise HTTPException(status_code=422, detail=f"Invalid FEN: {err}") from err


def _first_match_per_puzzle(matches):
    """Map each puzzle row to the move index of its first match, keeping match order."""
    first = {}
    for _pos_id, (puzzle_row, move_idx) in matches:
        first.setdefault(puzzle_row, move_idx)
    return first


def _puzzle_result(puzzle_row, move_idx):
    """Build the JSON payload for one matched puzzle row."""
    row = dset[puzzle_row]
    positions = get_puzzle_positions(row['FEN'], row['Moves'])
    return {
        "PuzzleId": row['PuzzleId'],
        # FEN of the position that matched the query (after move_idx + 1 moves),
        # not the row's starting FEN.
        "FEN": positions[move_idx].fen(),
        "Moves": row['Moves'],
        "Rating": row['Rating'],
        "Popularity": row['Popularity'],
        "Themes": row['Themes'],
        "MatchedMove": move_idx,
    }


@app.post("/search")
def search(data: dict):
    """Return the puzzles whose positions contain every piece of the posted position.

    Request body: ``{"fen": "<FEN string>"}``. Every (piece, square) token
    of the query board must be present in a matching position. Assuming the
    index maps each token to the positions holding that piece on that square,
    extra pieces in the puzzle position are allowed. Each puzzle is reported
    once, using its first matching position.

    Declared with plain ``def`` rather than ``async def`` because all of the
    work is blocking. FastAPI runs sync handlers in a threadpool, so a slow
    search no longer stalls the event loop for other requests.

    Raises:
        HTTPException: 422 if ``fen`` is missing, not a string, or unparsable.
    """
    start = time.perf_counter()  # monotonic clock, suitable for measuring elapsed time
    board = _parse_query_board(data)
    matches = query_positions(index, metadata, board_to_tokens(board))
    results = [
        _puzzle_result(puzzle_row, move_idx)
        for puzzle_row, move_idx in _first_match_per_puzzle(matches).items()
    ]
    elapsed_ms = (time.perf_counter() - start) * 1000
    return {"count": len(results), "results": results, "time_ms": elapsed_ms}