Spaces:
Runtime error
Runtime error
| import os | |
| os.environ["KERAS_BACKEND"] = "torch" | |
| from keras import models, utils | |
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| from huggingface_hub import hf_hub_download | |
_REPO_ID = "Ritvik19/SudokuNet"
# Loaded Keras models keyed by checkpoint name. At most one entry per dropdown choice.
_MODEL_CACHE = {}


def _load_model(model_name):
    """Return the SudokuNet model *model_name*, fetching and deserialising it only once.

    Without this cache every solve request would call ``hf_hub_download`` and
    rebuild the network with ``models.load_model``.
    """
    model = _MODEL_CACHE.get(model_name)
    if model is None:
        path = hf_hub_download(repo_id=_REPO_ID, filename=f"{model_name}.keras")
        model = models.load_model(path)
        _MODEL_CACHE[model_name] = model
    return model


def _parse_grid(sudoku):
    """Convert the grid input into a fresh (9, 9) int array with 0 for empty cells.

    Blank or non-numeric cells are treated as empty (0). The result is a new
    array, so filling it in never mutates the caller's data.

    Raises:
        gr.Error: if the grid is not 9x9 or holds values outside 0-9.
    """
    frame = pd.DataFrame(sudoku).apply(pd.to_numeric, errors="coerce").fillna(0)
    grid = np.array(frame.to_numpy(), dtype=int)
    if grid.shape != (9, 9):
        rows, cols = grid.shape
        raise gr.Error(f"Expected a 9x9 grid, got {rows}x{cols}.")
    if ((grid < 0) | (grid > 9)).any():
        raise gr.Error("Cells must contain digits 0-9 (use 0 for empty cells).")
    return grid


def echo_sudoku(sudoku, model_name):
    """Solve *sudoku* with the SudokuNet model *model_name*, one cell at a time.

    Each iteration runs the model on the current grid. It then fills only the
    single empty cell whose predicted digit has the highest probability, so
    later predictions see the earlier, more confident fills.

    Args:
        sudoku: 9x9 grid (the ``pandas.DataFrame`` from the Gradio input, or any
            array-like). Cells that are 0, blank or non-numeric count as empty.
        model_name: checkpoint name in the SudokuNet repo, e.g. ``"cnn__64x2"``.
            The ``.keras`` suffix is appended here.

    Returns:
        A (9, 9) int ``numpy.ndarray`` with every empty cell filled in.

    Raises:
        gr.Error: if the grid is not 9x9 or contains values outside 0-9.
    """
    # Validate before touching the hub so bad input fails fast.
    grid = _parse_grid(sudoku)
    model = _load_model(model_name)
    while True:
        empty = np.flatnonzero(grid == 0)  # flat indices of still-empty cells
        if empty.size == 0:
            return grid
        outputs = model.predict(
            utils.to_categorical(grid.reshape(1, 9, 9), num_classes=10), verbose=0
        )
        # One output per cell, named position_<row>_<col> (1-based). Stack them
        # into an (81, 9) table in row-major cell order, matching grid.flat.
        probs = np.stack([
            outputs[f"position_{row}_{col}"][0]
            for row in range(1, 10)
            for col in range(1, 10)
        ])
        best = empty[probs[empty].max(axis=1).argmax()]
        grid.flat[best] = probs[best].argmax() + 1  # class k -> digit k + 1
# Checkpoint names on the hub follow "<model_type>__<model_size>".
model_types = ["ffn", "cnn"]
model_sizes = ["64x2", "64x4", "128x2", "128x4"]
model_names = [
    f"{kind}__{size}"
    for kind in model_types
    for size in model_sizes
]
# Example puzzle pre-loaded into the input grid. 0 marks an empty cell.
# The text lives in its own constant so DEFAULT_PUZZLE is only ever an array.
_DEFAULT_PUZZLE_TEXT = """
0 0 4 3 0 0 2 0 9
0 0 5 0 0 9 0 0 1
0 7 0 0 6 0 0 4 3
0 0 6 0 0 2 0 8 7
1 9 0 0 0 7 4 0 0
0 5 0 0 8 3 0 0 0
6 0 0 0 0 0 1 0 5
0 0 3 5 0 8 6 9 0
0 4 2 9 1 0 3 0 0
"""
DEFAULT_PUZZLE = np.array(
    [int(digit) for digit in _DEFAULT_PUZZLE_TEXT.split()], dtype=int
).reshape(9, 9)
# Gradio UI: a fixed-size 9x9 input grid plus a model picker, and a 9x9 output grid.
interface = gr.Interface(
    fn=echo_sudoku,
    inputs=[
        gr.Dataframe(
            label="Input Sudoku Puzzle",
            datatype="number",
            # "fixed" stops users adding/removing rows or columns, which would
            # break the 9x9 shape the solver expects.
            row_count=(9, "fixed"),
            col_count=(9, "fixed"),
            value=DEFAULT_PUZZLE,
        ),
        gr.Dropdown(label="Select Model", choices=model_names, value="cnn__64x2"),
    ],
    # Distinct label so the solution is not mistaken for the input grid.
    outputs=gr.Dataframe(
        label="Solved Sudoku Puzzle",
        datatype="number",
        row_count=(9, "fixed"),
        col_count=(9, "fixed"),
    ),
    title="Sudoku Solver",
    description='A demo app for <a href="https://ritvik19.github.io/sudoku-net" target="_blank">SudokuNet</a>',
)
def main():
    """Start the Gradio server for the Sudoku solver demo in debug mode."""
    interface.launch(debug=True)


if __name__ == "__main__":
    main()