Custom Dataset Generator¶
This guide explains how to create custom datasets for training Transformer models on mathematical tasks using the CALT library. The library provides a flexible framework for generating various types of mathematical problems, from polynomial operations to integer arithmetic.
Overview¶
The CALT library offers a modular approach to dataset generation, allowing you to create custom problem generators for any mathematical task. After following the Quick Start guide, you can extend your experiments by creating custom datasets. The process involves two components:
- Custom Problem Generator: Defines how to create problem-solution pairs
- Custom Statistics Calculator (optional): Analyzes the generated dataset
1. Custom Problem Generator¶
To create your own problem generator, you need to define a class with a __call__ method according to your specific mathematical task.
Basic Structure¶
import sage.misc.randstate as randstate


class CustomProblemGenerator:
    def __init__(self, sampler, **kwargs):
        self.sampler = sampler
        # Add other parameters as needed

    def __call__(self, seed: int):
        # Set random seed for SageMath's random state
        randstate.set_random_seed(seed)
        # Your problem generation logic here
        # ...
        # Return a tuple (problem, solution)
        return problem, solution
Key Points¶
- Method Input: The __call__ method takes a seed parameter for reproducibility
- Reproducibility: Always use randstate.set_random_seed(seed) at the beginning
- Sampler: Use the provided sampler for generating mathematical objects (required for polynomial tasks, optional for arithmetic tasks)
- Method Output: Return a tuple (problem, solution)
Note: The problem and solution can be of various types depending on your mathematical task:
- Single values: int, float, MPolynomial_libsingular
- Lists: list[int], list[float], or list[MPolynomial_libsingular]
- Matrices: Nested lists, e.g., list[list[int]] or list[list[MPolynomial_libsingular]]
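As a rough illustration of the structure and return types described above, the following sketch generates an integer arithmetic sample where the problem is a list[int] and the solution is a single int. The class name IntegerSumProblemGenerator and the parameters num_terms and max_value are hypothetical and not part of CALT. To stay runnable without SageMath, the sketch seeds Python's standard random module; inside a SageMath environment, randstate.set_random_seed(seed) would play that role instead.

```python
import random


class IntegerSumProblemGenerator:
    """Hypothetical generator: the problem is a list of integers, the solution is their sum."""

    def __init__(self, sampler=None, num_terms: int = 4, max_value: int = 99):
        # The sampler is unused here; it is optional for arithmetic tasks.
        self.sampler = sampler
        self.num_terms = num_terms
        self.max_value = max_value

    def __call__(self, seed: int):
        # Seed a local generator so the same seed always yields the same sample.
        # Assumption: with SageMath, randstate.set_random_seed(seed) replaces this line.
        rng = random.Random(seed)
        problem = [rng.randint(0, self.max_value) for _ in range(self.num_terms)]
        solution = sum(problem)
        # Return a tuple (problem, solution)
        return problem, solution


generator = IntegerSumProblemGenerator(num_terms=3, max_value=9)
problem, solution = generator(seed=42)
print(problem, solution)
```

Calling the generator twice with the same seed should return the same pair, which is the property the seed parameter is meant to provide.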
For concrete examples of problem generators, see Problem Generator Examples.
2. Custom Statistics Calculator (optional)¶
Create a custom statistics calculator to analyze your generated data. The calculator should inherit from BaseStatisticsCalculator and implement the __call__ method.
Basic Structure¶
from typing import Any

from calt.dataset_generator.sagemath import BaseStatisticsCalculator


class CustomStatisticsCalculator(BaseStatisticsCalculator):
    def __call__(
        self, problem: Any, solution: Any
    ) -> dict[str, dict[str, int | float]]:
        """
        Calculate statistics for a single sample.

        Args:
            problem: The problem data
            solution: The solution data

        Returns:
            Dictionary with keys "problem" and "solution", each mapping to a
            sub-dictionary containing descriptive statistics.
        """
        return {
            "problem": self.calculate_problem_stats(problem),
            "solution": self.calculate_solution_stats(solution),
        }

    def calculate_problem_stats(self, problem) -> dict[str, int | float]:
        """Calculate statistics for the problem data."""
        # Your problem statistics logic here
        pass

    def calculate_solution_stats(self, solution) -> dict[str, int | float]:
        """Calculate statistics for the solution data."""
        # Your solution statistics logic here
        pass
Key Points¶
- Inheritance: Must inherit from BaseStatisticsCalculator
- Method Input: The __call__ method takes problem and solution.
- Method Output: Return a dictionary dict[str, dict[str, int | float]] with keys "problem" and "solution".
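To make the expected output shape concrete, here is a minimal sketch of a calculator for samples whose problem is a list[int] and whose solution is an int. The class name IntegerListStatisticsCalculator and the statistic names (num_terms, max_value, mean_value, value, num_digits) are hypothetical choices for illustration. So that the sketch runs without CALT or SageMath installed, it inherits from a local placeholder class; in a real project the class would inherit from BaseStatisticsCalculator instead.

```python
from __future__ import annotations

from typing import Any


class _StandInBase:
    """Placeholder base class (assumption): swap in CALT's BaseStatisticsCalculator."""


class IntegerListStatisticsCalculator(_StandInBase):
    def __call__(
        self, problem: Any, solution: Any
    ) -> dict[str, dict[str, int | float]]:
        # One sub-dictionary of statistics per side of the sample.
        return {
            "problem": self.calculate_problem_stats(problem),
            "solution": self.calculate_solution_stats(solution),
        }

    def calculate_problem_stats(self, problem: list[int]) -> dict[str, int | float]:
        # Size and magnitude of the integers in the problem.
        return {
            "num_terms": len(problem),
            "max_value": max(problem),
            "mean_value": sum(problem) / len(problem),
        }

    def calculate_solution_stats(self, solution: int) -> dict[str, int | float]:
        # The value itself and its number of decimal digits (sign ignored).
        return {"value": solution, "num_digits": len(str(abs(solution)))}


calculator = IntegerListStatisticsCalculator()
stats = calculator([3, 5, 10], 18)
print(stats)
```

Each sub-dictionary maps a statistic name to a number, which matches the dict[str, dict[str, int | float]] return type listed in the key points.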
For concrete examples of statistics calculators, see Statistics Calculator Examples.