zensols.lmtask package

Submodules

zensols.lmtask.app module

Large language model experimentation.

class zensols.lmtask.app.Application(config_factory, task_factory)[source]

Bases: object

Large language model experimentation.

__init__(config_factory, task_factory)
config_factory: ConfigFactory

Used to create training resources.

dataset_sample(max_sample=1)[source]

Print sample(s) of the configured (--config) dataset.

Parameters:

max_sample (int) – the number of samples to print

instruct(task_name, instruction, role=None, output_format=None)[source]

Generate text by inferencing with the model.

Parameters:
  • task_name (str) – the task that generates the result

  • instruction (str) – added to the prompt to instruct the model

  • role (str) – the role the model takes

show_task(task_name=None)[source]

Print the configuration of a task if --name is given, otherwise a list of available tasks.

Parameters:

task_name (str) – the task that creates the prompt and parses the result

show_trainer(long_output=False)[source]

Print configuration and dataset stats of the configured (--config) trainer.

Parameters:

long_output (bool) – verbosity

stream(task_name, prompt)[source]

Stream generated text from the model.

Parameters:
  • task_name (str) – the task that generates the result

  • prompt (str) – the prompt text as input to the model

task_factory: TaskFactory

Create tasks used to fulfill CLI requests.

train()[source]

Train a new model on a configured (--config) dataset.

class zensols.lmtask.app.PrototypeApplication(config_factory, app, prompt='Once upon a time, in a galaxy, far far away,')[source]

Bases: object

Used by the Python REPL for prototyping.

CLI_META = {'is_usage_visible': False}
__init__(config_factory, app, prompt='Once upon a time, in a galaxy, far far away,')
app: Application
config_factory: ConfigFactory
prompt: str = 'Once upon a time, in a galaxy, far far away,'
proto(run=0)[source]

zensols.lmtask.cli module

Command line entry point to the application.

class zensols.lmtask.cli.ApplicationFactory(*args, **kwargs)[source]

Bases: ApplicationFactory

__init__(*args, **kwargs)[source]
classmethod get_application()[source]

Get a text generator instance.

Return type:

TextGenerator

classmethod get_task_factory()[source]

Get the factory that creates tasks.

Return type:

TaskFactory
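
For example, these factory entry points can be used programmatically. A minimal sketch, assuming a task named 'instruct' is defined in the application configuration:

    from zensols.lmtask.cli import ApplicationFactory
    from zensols.lmtask.task import TaskFactory, TaskResponse
    from zensols.lmtask.instruct import InstructTaskRequest

    # create the task factory from the application configuration
    factory: TaskFactory = ApplicationFactory.get_task_factory()
    # 'instruct' is a hypothetical task name; see TaskFactory.task_names
    task = factory.create('instruct')
    response: TaskResponse = task.process(
        InstructTaskRequest(instruction='Summarize the plot of Hamlet.'))
    print(response.model_output)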

zensols.lmtask.cli.main(args=sys.argv, **kwargs)[source]
Return type:

ActionResult

zensols.lmtask.dataset module

An implementation of a dataset generator (task.TaskDatasetFactory).

class zensols.lmtask.dataset.LoadedTaskDatasetFactory(task, text_field='text', messages_field='messages', eval_field='text', source=None, load_args=<factory>, pre_process=None, post_process=None)[source]

Bases: TaskDatasetFactory

A utility class meant to be created from an application configuration. This class creates a dataset used by Trainer and optionally does post-processing (i.e. filtering and mapping).

__init__(task, text_field='text', messages_field='messages', eval_field='text', source=None, load_args=<factory>, pre_process=None, post_process=None)
static clear_generator_cache()[source]
load_args: Dict[str, Any]

Additional arguments given to datasets.load_dataset().

post_process: Union[str, Callable] = None

Code to call after the dataset is created and the task has applied any template.

See:

pre_process

pre_process: Union[str, Callable] = None

Code to call after the dataset is created but before the task applies any template. If this is a string, exec() is used to evaluate it. Otherwise, it is treated as a callable where the old dataset is the input and the returned value replaces the dataset.

source: Union[str, Path, Stash, DataFrame, Dataset] = None

Used as the source data in the created dataset.
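
A hedged construction sketch follows; this class is normally created from the application configuration, and the task instance and CSV path below are assumptions:

    from pathlib import Path
    from zensols.lmtask.dataset import LoadedTaskDatasetFactory

    # 'task' is assumed to be a previously created Task instance
    ds_factory = LoadedTaskDatasetFactory(
        task=task,
        # a hypothetical CSV file with a 'text' column
        source=Path('corpus/train.csv'),
        # drop empty rows before the task applies its template
        pre_process=lambda ds: ds.filter(lambda r: len(r['text']) > 0))
    dataset = ds_factory.create()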

zensols.lmtask.generate module

Facade to HuggingFace text generation.

class zensols.lmtask.generate.CachingGenerator(_delegate, _stash, _hasher=<factory>)[source]

Bases: TextGenerator

A generator that caches responses using a hash of the model input as a key.

__init__(_delegate, _stash, _hasher=<factory>)
clear()[source]

Clear any model state.

class zensols.lmtask.generate.ConstantTextGenerator(config_factory, response, post_init_source=None)[source]

Bases: TextGenerator

A generator that replies with response on every generation call, for debugging purposes.

__init__(config_factory, response, post_init_source=None)
config_factory: ConfigFactory

Used to set optional mock attributes in post_init_source.

post_init_source: str = None

Python source code to run in the initializer.

response: str

The fixed response for each generate() call or the prompt if None.
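
A sketch of using this class as a debugging stand-in; passing config_factory=None is an assumption that holds only when post_init_source is unused:

    from zensols.lmtask.generate import ConstantTextGenerator

    # always returns the same text regardless of the prompt
    gen = ConstantTextGenerator(config_factory=None, response='mock output')
    print(gen.generate('any prompt').model_output)  # -> 'mock output'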

class zensols.lmtask.generate.GenerateTask(name, description, request_class, response_class, generator, resource, train_add_eos=False)[source]

Bases: Task

Uses a TextGenerator (generator) to generate a response.

__init__(name, description, request_class, response_class, generator, resource, train_add_eos=False)
clear()[source]

Clear any generator state or cache.

generator: TextGenerator

A client facade of a chat or instruct-based large language model.

resource: GeneratorResource

The class that creates resources such as the tokenizer and model. This should be the base model resource so training tasks do not depend on the model they will eventually create.

This is also used by InstructTask for its chat template.

train_add_eos: bool = False

Whether to add the end of sentence token to the output when mapping the dataset for training. Newer versions of the trl.SFTTrainer class add (and force) this already.

class zensols.lmtask.generate.GeneratorOutput(model_output, parsed)[source]

Bases: Dictable

A container for model output from TextGenerator.

__init__(model_output, parsed)
model_output: str

The unmodified raw model output.

parsed: Tuple[str, ...]

The processed model output with special tokens stripped.

class zensols.lmtask.generate.GeneratorResource(name, model_id, model_class=<class 'transformers.models.auto.modeling_auto.AutoModelForCausalLM'>, tokenizer_class=<class 'transformers.models.auto.tokenization_auto.AutoTokenizer'>, peft_model_id=None, peft_model_class=<class 'peft.auto.AutoPeftModelForCausalLM'>, model_desc=None, system_role_name='system', model_args=<factory>)[source]

Bases: Dictable

A client facade of a chat-based large language model.

__init__(name, model_id, model_class=<class 'transformers.models.auto.modeling_auto.AutoModelForCausalLM'>, tokenizer_class=<class 'transformers.models.auto.tokenization_auto.AutoTokenizer'>, peft_model_id=None, peft_model_class=<class 'peft.auto.AutoPeftModelForCausalLM'>, model_desc=None, system_role_name='system', model_args=<factory>)
clear(include_cuda=True)[source]

Clear the cached tokenizer, model and optionally CUDA.

configure_model(model)[source]

Make any necessary updates programmatically.

configure_tokenizer(tokenizer)[source]

Make any necessary updates programmatically (i.e. set special tokens).

classmethod get_model_path(model_id, parent=None)[source]

Create a normalized file name from a HF model ID string useful for creating checkpoint directory names.

Parameters:
  • model_id (str) – the model ID (e.g. meta-llama/Llama-3.1-8B)

  • parent (Path) – the base directory used in the return value if given

Return type:

Path
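
For example (the exact normalized file name is illustrative):

    from pathlib import Path
    from zensols.lmtask.generate import GeneratorResource

    # turn an HF model ID into a file system friendly checkpoint path
    path: Path = GeneratorResource.get_model_path(
        'meta-llama/Llama-3.1-8B', parent=Path('checkpoints'))
    print(path)  # e.g. checkpoints/<normalized model name>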

property model: PreTrainedModel

The LLM.

model_args: Dict[str, Any]

The arguments given to the HF model from_pretrained method.

model_class

The class used to create the model with from_pretrained().

alias of AutoModelForCausalLM

model_desc: str = None

A human readable description of the model this resource contains.

property model_file_name: str

A normalized file name friendly string based on model_desc.

model_id: Union[str, Path]

The HF model ID or path to the model.

name: str

The section of this configured instance in the application config.

peft_model_class

The class used to create the model with from_pretrained().

alias of AutoPeftModelForCausalLM

peft_model_id: Union[str, Path] = None

The HF model ID or path to the Peft model or None if there is none.

system_role_name: str = 'system'

The default name of the system’s role.

property tokenizer: PreTrainedTokenizer

The model’s tokenizer.

tokenizer_class

The class used to create the tokenizer with from_pretrained().

alias of AutoTokenizer

class zensols.lmtask.generate.ModelTextGenerator(resource, tokenize_params=<factory>, tokenize_decode_params=<factory>, generate_params=<factory>)[source]

Bases: TextGenerator

An implementation that uses HuggingFace framework classes from GeneratorResource to answer queries.

__init__(resource, tokenize_params=<factory>, tokenize_decode_params=<factory>, generate_params=<factory>)
clear()[source]

Clear any model state.

generate_params: Dict[str, Any]

Parameters given to the model’s inference method for each prompt.

resource: GeneratorResource

The class that creates resources such as the tokenizer and model.

stream(prompt, writer=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>, width=80)[source]

Stream the model’s output from a prompt input.

Parameters:
  • prompt (str) – the input to give to the model

  • writer (TextIOBase) – the data sink

  • width (int) – the maximum width of each line’s streamed text; if None, no modification will be done on the text output

tokenize_decode_params: Dict[str, Any]

Parameters to add or override in the tokenizer decode call.

tokenize_params: Dict[str, Any]

Parameters to add or override in the model tokenize call.
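
A minimal usage sketch; the config section name, model ID and generation parameters are assumptions, and the first call will load (and possibly download) the model:

    from zensols.lmtask.generate import (
        GeneratorResource, ModelTextGenerator, GeneratorOutput
    )

    # 'example' names the (assumed) application config section
    resource = GeneratorResource(
        name='example', model_id='meta-llama/Llama-3.1-8B')
    gen = ModelTextGenerator(
        resource=resource,
        # sampling settings passed to the model's inference method
        generate_params=dict(max_new_tokens=128, do_sample=True))
    out: GeneratorOutput = gen.generate('Explain beam search briefly.')
    print(out.parsed[0])  # processed output, special tokens stripped
    gen.stream('Once upon a time', width=72)  # or stream the output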

class zensols.lmtask.generate.ReplaceTextGenerator(resource, tokenize_params=<factory>, tokenize_decode_params=<factory>, generate_params=<factory>, replacements=())[source]

Bases: ModelTextGenerator

A text generator that generates response by replacing regular expressions. This is helpful for removing special tokens.

__init__(resource, tokenize_params=<factory>, tokenize_decode_params=<factory>, generate_params=<factory>, replacements=())
replacements: Tuple[Tuple[Union[str, Pattern], str], ...] = ()

A tuple of (<regular expression>, <replacement>) pairs to apply to the parsed output from the model. String patterns are compiled with re.compile().
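
A short sketch; the end-of-turn token pattern below is an assumption for illustration:

    import re
    from zensols.lmtask.generate import ReplaceTextGenerator

    # strip an (assumed) special token and collapse extra blank lines
    gen = ReplaceTextGenerator(
        resource=resource,  # a GeneratorResource created elsewhere
        replacements=((re.compile(r'<\|eot_id\|>'), ''),
                      (r'\n{3,}', '\n\n')))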

class zensols.lmtask.generate.TextGenerator[source]

Bases: Dictable

A client facade of a chat-based large language model.

__init__()
clear()[source]

Clear any model state.

generate(prompt)[source]

Generate a textual response (usually from a large language model).

Return type:

GeneratorOutput

zensols.lmtask.hf module

HuggingFace trainer wrapper.

class zensols.lmtask.hf.HFTrainerResource(model_args=None, cache=True, generator_resource=None, peft_config=None)[source]

Bases: TrainerResource

Uses HuggingFaceTrainer for training the model.

__init__(model_args=None, cache=True, generator_resource=None, peft_config=None)
generator_resource: GeneratorResource = None

The resource used to load the source checkpoint.

peft_config: LoraConfig = None

The Peft low rank adapters configuration.

class zensols.lmtask.hf.HuggingFaceTrainer(config, resource, train_params, eval_params, train_source, eval_source, peft_output_dir, merged_output_dir)[source]

Bases: Trainer

The HuggingFace trainer.

__init__(config, resource, train_params, eval_params, train_source, eval_source, peft_output_dir, merged_output_dir)

zensols.lmtask.instruct module

Task implementations.

class zensols.lmtask.instruct.InstructModelTextGenerator(resource, tokenize_params=<factory>, tokenize_decode_params=<factory>, generate_params=<factory>, replacements=())[source]

Bases: ReplaceTextGenerator

A generator that uses instruct based models for inference.

__init__(resource, tokenize_params=<factory>, tokenize_decode_params=<factory>, generate_params=<factory>, replacements=())
class zensols.lmtask.instruct.InstructTask(name, description, request_class, response_class, generator, resource, train_add_eos=False, role='You are a helpful assistant.', train_template='### Question: {{ instruction }}\\n### Answer: {{ output }}', inference_template='{{request.instruction}}', chat_template_args=<factory>, apply_chat_template=True, train_apply_chat_template=False)[source]

Bases: GenerateTask

A task that is resolved using instructions given to the language model.

Important: If InstructTaskRequest.model_input is non-None, that value is used verbatim and InstructTaskRequest.instruction is ignored.

__init__(name, description, request_class, response_class, generator, resource, train_add_eos=False, role='You are a helpful assistant.', train_template='### Question: {{ instruction }}\\n### Answer: {{ output }}', inference_template='{{request.instruction}}', chat_template_args=<factory>, apply_chat_template=True, train_apply_chat_template=False)
apply_chat_template: bool = True

Whether to format the prompt into one that conforms to the model’s instruct syntax.

chat_template_args: Dict[str, Any]

Arguments given to apply_chat_template.

inference_template: Union[str, Path] = '{{request.instruction}}'

The instructions given to the generator.

role: str = 'You are a helpful assistant.'

The role of the chat dialogue.

train_apply_chat_template: bool = False

Like apply_chat_template, but applied during training. If this is False, a conversational list of message dictionaries is used instead.

train_template: Union[str, Path] = '### Question: {{ instruction }}\n### Answer: {{ output }}'

Used to format the dataset’s training text.
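
A construction sketch; the task name, role and templates are assumptions, and the generator and resource are presumed to be created elsewhere:

    from zensols.lmtask.task import TaskResponse
    from zensols.lmtask.instruct import InstructTask, InstructTaskRequest

    # generator/resource come from zensols.lmtask.generate (assumed)
    task = InstructTask(
        name='qa', description='question answering',
        request_class=InstructTaskRequest, response_class=TaskResponse,
        generator=generator, resource=resource,
        role='You are a terse domain expert.',
        inference_template='{{request.instruction}}')
    req = task.prepare_request(
        InstructTaskRequest(instruction='What is a LoRA adapter?'))
    print(req.model_input)  # the fully compiled prompt sent to the model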

write(depth=0, writer=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>)[source]

Write this instance as either a Writable or as a Dictable. If class attribute _DICTABLE_WRITABLE_DESCENDANTS is set as True, then use the write() method on children instead of writing the generated dictionary. Otherwise, write this instance by first creating a dict recursively using asdict(), then formatting the output.

If the attribute _DICTABLE_WRITE_EXCLUDES is set, those attributes are removed from what is written in the write() method.

Note that this attribute will need to be set in all descendants in the instance hierarchy since writing the object instance graph is done recursively.

Parameters:
  • depth (int) – the starting indentation depth

  • writer (TextIOBase) – the writer to dump the content of this writable

class zensols.lmtask.instruct.InstructTaskRequest(model_input=None, instruction=None)[source]

Bases: TaskRequest

A request that has a query portion to be added to the compiled prompt.

__init__(model_input=None, instruction=None)
instruction: Any = None

The instruction given to the model to complete the task.

write(depth=0, writer=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>, include_instruction=True)[source]

Write this instance as either a Writable or as a Dictable. If class attribute _DICTABLE_WRITABLE_DESCENDANTS is set as True, then use the write() method on children instead of writing the generated dictionary. Otherwise, write this instance by first creating a dict recursively using asdict(), then formatting the output.

If the attribute _DICTABLE_WRITE_EXCLUDES is set, those attributes are removed from what is written in the write() method.

Note that this attribute will need to be set in all descendants in the instance hierarchy since writing the object instance graph is done recursively.

Parameters:
  • depth (int) – the starting indentation depth

  • writer (TextIOBase) – the writer to dump the content of this writable

class zensols.lmtask.instruct.NShotTaskRequest(model_input=None, instruction=None, examples=None)[source]

Bases: InstructTaskRequest

A request that adds training examples to the prompt.

__init__(model_input=None, instruction=None, examples=None)
examples: Tuple[Any, ...] = None

The examples given for N-shot learning.
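
A brief sketch; the example strings are illustrative:

    from zensols.lmtask.instruct import NShotTaskRequest

    # a few-shot request with inline labeled examples
    request = NShotTaskRequest(
        instruction='Classify the sentiment of: "The plot dragged."',
        examples=('"Loved it!" -> positive',
                  '"A total bore." -> negative'))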

zensols.lmtask.llama module

Interactive chat interfaces, which are a superset of chat templates.

class zensols.lmtask.llama.LlamaGeneratorResource(name, model_id, model_class=<class 'transformers.models.auto.modeling_auto.AutoModelForCausalLM'>, tokenizer_class=<class 'transformers.models.auto.tokenization_auto.AutoTokenizer'>, peft_model_id=None, peft_model_class=<class 'peft.auto.AutoPeftModelForCausalLM'>, model_desc=None, system_role_name='system', model_args=<factory>)[source]

Bases: GeneratorResource

There are four roles supported by Llama text models:

  • system: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that help the model respond effectively.

  • user: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model.

  • ipython: A new role introduced in Llama 3.1. Semantically, this role means “tool”. This role is used to mark messages with the output of a tool call when sent back to the model from the executor.

  • assistant: Represents the response generated by the AI model based on the context provided in the system, ipython and user prompts.

__init__(name, model_id, model_class=<class 'transformers.models.auto.modeling_auto.AutoModelForCausalLM'>, tokenizer_class=<class 'transformers.models.auto.tokenization_auto.AutoTokenizer'>, peft_model_id=None, peft_model_class=<class 'peft.auto.AutoPeftModelForCausalLM'>, model_desc=None, system_role_name='system', model_args=<factory>)
configure_model(model)[source]

Make any necessary updates programmatically.

configure_tokenizer(tokenizer)[source]

Make any necessary updates programmatically (i.e. set special tokens).

zensols.lmtask.task module

Task implementations.

class zensols.lmtask.task.JSONTaskResponse(request, model_output_raw, model_output, robust_json=True)[source]

Bases: TaskResponse

A response that parses the model output as JSON. The JSON is parsed as much as possible and errors are not raised when the JSON is incomplete.

__init__(request, model_output_raw, model_output, robust_json=True)
property any_failures: bool

Whether any failures were created during JSON parsing.

property model_output_json: Failure | str

The response attribute parsed as JSON.

Raises:

json.decoder.JSONDecodeError – if the JSON failed to parse

See:

robust_json

robust_json: bool = True

Whether to return Failure from model_output_json instead of raising from parse failures.
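
A sketch of the robust parsing behavior; the truncated JSON and the request object are assumptions:

    from zensols.lmtask.task import JSONTaskResponse

    # truncated output is parsed as far as possible when robust_json=True
    resp = JSONTaskResponse(
        request=request,                     # a previously built TaskRequest
        model_output_raw='{"answer": "42"',  # verbatim (truncated) output
        model_output='{"answer": "42"',
        robust_json=True)
    if not resp.any_failures:
        print(resp.model_output_json)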

write(depth=0, writer=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>, include_request=False, include_model_output=False, include_json=True)[source]

Write this instance as either a Writable or as a Dictable. If class attribute _DICTABLE_WRITABLE_DESCENDANTS is set as True, then use the write() method on children instead of writing the generated dictionary. Otherwise, write this instance by first creating a dict recursively using asdict(), then formatting the output.

If the attribute _DICTABLE_WRITE_EXCLUDES is set, those attributes are removed from what is written in the write() method.

Note that this attribute will need to be set in all descendants in the instance hierarchy since writing the object instance graph is done recursively.

Parameters:
  • depth (int) – the starting indentation depth

  • writer (TextIOBase) – the writer to dump the content of this writable

class zensols.lmtask.task.Task(name, description, request_class, response_class)[source]

Bases: Dictable

Subclasses turn a prompt and query into a response from an LLM.

__init__(name, description, request_class, response_class)
clear()[source]

Clear any generator state or cache.

description: str

A description of the task.

name: str

The name of the task.

prepare_dataset(ds, factory)[source]

Massage any data necessary to train this task. This might involve applying templates and/or adding terminating tokens.

Return type:

Dataset

prepare_request(request)[source]

Return a request with the contents populated with a formatted prompt.

Return type:

TaskRequest

process(request)[source]

Invoke the generator to query the LLM, then return JSON-formatted data.

Parameters:

request – a query that is phrased with the assumption that JSON is given as a response

Return type:

TaskResponse

request_class: Type[TaskRequest]

The request data.

response_class: Type[TaskResponse]

The response data.

class zensols.lmtask.task.TaskDatasetFactory(task, text_field='text', messages_field='messages', eval_field='text')[source]

Bases: Dictable

Subclasses create datasets used by Trainer and optionally do post-processing (i.e. filtering and mapping).

__init__(task, text_field='text', messages_field='messages', eval_field='text')
create()[source]

Create a new dataset based on source.

Return type:

Dataset

Returns:

the new dataset after modification by post_process

eval_field: str = 'text'

The field used for comparison with the evaluation dataset.

messages_field: str = 'messages'

The target conversational field used by the trainer.

task: Task

The task that helps format text in datasets.

text_field: str = 'text'

The target text field used by the trainer.

write(depth=0, writer=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>)[source]

Write this instance as either a Writable or as a Dictable. If class attribute _DICTABLE_WRITABLE_DESCENDANTS is set as True, then use the write() method on children instead of writing the generated dictionary. Otherwise, write this instance by first creating a dict recursively using asdict(), then formatting the output.

If the attribute _DICTABLE_WRITE_EXCLUDES is set, those attributes are removed from what is written in the write() method.

Note that this attribute will need to be set in all descendants in the instance hierarchy since writing the object instance graph is done recursively.

Parameters:
  • depth (int) – the starting indentation depth

  • writer (TextIOBase) – the writer to dump the content of this writable

exception zensols.lmtask.task.TaskDatasetFactoryError(message, prompt=None)[source]

Bases: TaskError

Raised when TaskDatasetFactory instances can not create datasets.

__module__ = 'zensols.lmtask.task'
exception zensols.lmtask.task.TaskError(message, prompt=None)[source]

Bases: APIError

Raised for any LLM specific error in this API.

__annotations__ = {}
__init__(message, prompt=None)[source]
__module__ = 'zensols.lmtask.task'
class zensols.lmtask.task.TaskFactory(config_factory, _task_pattern)[source]

Bases: Dictable

Creates instances of Task using create().

__init__(config_factory, _task_pattern)
config_factory: ConfigFactory

The factory that creates tasks.

create(name)[source]

Create a new instance of a task with name as defined in the application config.

See:

task_names()

Return type:

Task

property task_names: Set[str]

The names of the tasks available to create with create().

write(depth=0, writer=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>, short=False)[source]

Write this instance as either a Writable or as a Dictable. If class attribute _DICTABLE_WRITABLE_DESCENDANTS is set as True, then use the write() method on children instead of writing the generated dictionary. Otherwise, write this instance by first creating a dict recursively using asdict(), then formatting the output.

If the attribute _DICTABLE_WRITE_EXCLUDES is set, those attributes are removed from what is written in the write() method.

Note that this attribute will need to be set in all descendants in the instance hierarchy since writing the object instance graph is done recursively.

Parameters:
  • depth (int) – the starting indentation depth

  • writer (TextIOBase) – the writer to dump the content of this writable

class zensols.lmtask.task.TaskObject[source]

Bases: PersistableContainer, Dictable

Base class for task requests and responses.

__init__()
class zensols.lmtask.task.TaskRequest(model_input=None)[source]

Bases: TaskObject

The input request to the LLM via Task.process(). In most cases, model_input can be used to skip the prompt compilation step.

__init__(model_input=None)
model_input: str = None

The text given verbatim to the model. This is some combination of both query and prompt.

write(depth=0, writer=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>)[source]

Write this instance as either a Writable or as a Dictable. If class attribute _DICTABLE_WRITABLE_DESCENDANTS is set as True, then use the write() method on children instead of writing the generated dictionary. Otherwise, write this instance by first creating a dict recursively using asdict(), then formatting the output.

If the attribute _DICTABLE_WRITE_EXCLUDES is set, those attributes are removed from what is written in the write() method.

Note that this attribute will need to be set in all descendants in the instance hierarchy since writing the object instance graph is done recursively.

Parameters:
  • depth (int) – the starting indentation depth

  • writer (TextIOBase) – the writer to dump the content of this writable

class zensols.lmtask.task.TaskResponse(request, model_output_raw, model_output)[source]

Bases: TaskObject

The happy-path response given by Task.

__init__(request, model_output_raw, model_output)
model_output: str

The parsed response text given by the model.

model_output_raw: str

The model output verbatim.

request: TaskRequest

The request used to generate this response.

write(depth=0, writer=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>, include_request=False, include_model_output=True, include_model_output_raw=False)[source]

Write this instance as either a Writable or as a Dictable. If class attribute _DICTABLE_WRITABLE_DESCENDANTS is set as True, then use the write() method on children instead of writing the generated dictionary. Otherwise, write this instance by first creating a dict recursively using asdict(), then formatting the output.

If the attribute _DICTABLE_WRITE_EXCLUDES is set, those attributes are removed from what is written in the write() method.

Note that this attribute will need to be set in all descendants in the instance hierarchy since writing the object instance graph is done recursively.

Parameters:
  • depth (int) – the starting indentation depth

  • writer (TextIOBase) – the writer to dump the content of this writable

zensols.lmtask.train module

Continued pretraining and supervised fine-tuning.

class zensols.lmtask.train.ModelResult(train_output, output_dir=None, train_params=None, config=None)[source]

Bases: Dictable

The trained model config, location and configuration used to train it.

__init__(train_output, output_dir=None, train_params=None, config=None)
config: Configurable = None

The application configuration used to configure the trainer.

property global_step: int

The global step from train_output.

property metrics: Dict[str, float]

Training metrics from train_output.

output_dir: Path = None

The directory of the models checkpoints.

train_output: TrainOutput

The output returned from the trainer.

train_params: Dict[str, Any] = None

The training parameters used to configure the trainer.

property training_loss: float

The training loss from train_output.

write(depth=0, writer=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>, include_training_arguments=False, include_config=False)[source]

Write this instance as either a Writable or as a Dictable. If class attribute _DICTABLE_WRITABLE_DESCENDANTS is set as True, then use the write() method on children instead of writing the generated dictionary. Otherwise, write this instance by first creating a dict recursively using asdict(), then formatting the output.

If the attribute _DICTABLE_WRITE_EXCLUDES is set, those attributes are removed from what is written in the write() method.

Note that this attribute will need to be set in all descendants in the instance hierarchy since writing the object instance graph is done recursively.

Parameters:
  • depth (int) – the starting indentation depth

  • writer (TextIOBase) – the writer to dump the content of this writable

class zensols.lmtask.train.Trainer(config, resource, train_params, eval_params, train_source, eval_source, peft_output_dir, merged_output_dir)[source]

Bases: Dictable

An UnslothTrainer wrapper.

__init__(config, resource, train_params, eval_params, train_source, eval_source, peft_output_dir, merged_output_dir)
config: Configurable

The configuration saved to the model result.

eval_params: Dict[str, Any]

The evaluation parameters used to configure the trainer.

eval_source: TaskDatasetFactory

A factory that creates new datasets used for evaluation.

merged_output_dir: Union[str, Path]

The directory to save the base and PEFT weights merged into one model.

peft_output_dir: Union[str, Path]

The directory to save the Peft model.

resource: TrainerResource

Used to create the model and tokenizer.

train()[source]

Train the model.

Return type:

ModelResult
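
A usage sketch, assuming a trainer instance obtained from the application configuration:

    from zensols.lmtask.train import ModelResult

    result: ModelResult = trainer.train()
    print(result.training_loss)  # final training loss
    print(result.metrics)        # metrics reported by the trainer
    result.write()               # summarize configuration and output location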

train_params: Dict[str, Any]

The training parameters used to configure the trainer.

train_source: TaskDatasetFactory

A factory that creates new datasets used to train using this instance.

write(depth=0, writer=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>, include_training_arguments=False)[source]

Write this instance as either a Writable or as a Dictable. If class attribute _DICTABLE_WRITABLE_DESCENDANTS is set as True, then use the write() method on children instead of writing the generated dictionary. Otherwise, write this instance by first creating a dict recursively using asdict(), then formatting the output.

If the attribute _DICTABLE_WRITE_EXCLUDES is set, those attributes are removed from what is written in the write() method.

Note that this attribute will need to be set in all descendants in the instance hierarchy since writing the object instance graph is done recursively.

Parameters:
  • depth (int) – the starting indentation depth

  • writer (TextIOBase) – the writer to dump the content of this writable

class zensols.lmtask.train.TrainerResource(model_args=None, cache=True)[source]

Bases: Dictable, Primeable

Configures and instantiates the base model, PEFT model, and the tokenizer.

__init__(model_args=None, cache=True)
cache: bool = True

Whether to cache the tokenizer and model.

property model: PreTrainedModel

The base model.

model_args: Dict[str, Any] = None

The parameters that create the base model and tokenizer.

property peft_model: PeftModelForCausalLM

The PEFT (Parameter-Efficient Fine-Tuning) model, such as LoRA.

prime()[source]
property tokenizer: PreTrainedTokenizer

The base tokenizer.

Module contents