Publish Python SDK

Hello world!

Signed-off-by: Daniel Xu <dxu@dxuuu.xyz>
This commit is contained in:
Daniel Xu 2025-07-15 02:24:04 +00:00 committed by Daniel Xu
commit 829c151ba7
192 changed files with 25717 additions and 0 deletions

8
.devcontainer/Dockerfile Normal file
View file

@ -0,0 +1,8 @@
ARG VARIANT="3.9"
FROM mcr.microsoft.com/vscode/devcontainers/python:0-${VARIANT}
USER vscode
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
RUN echo "[[ -d .venv ]] && source .venv/bin/activate || export PATH=\$PATH" >> /home/vscode/.bashrc

View file

@ -0,0 +1,43 @@
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/debian
{
"name": "Debian",
"build": {
"dockerfile": "Dockerfile",
"context": ".."
},
"postStartCommand": "uv sync --all-extras",
"customizations": {
"vscode": {
"extensions": [
"ms-python.python"
],
"settings": {
"terminal.integrated.shell.linux": "/bin/bash",
"python.pythonPath": ".venv/bin/python",
"python.defaultInterpreterPath": ".venv/bin/python",
"python.typeChecking": "basic",
"terminal.integrated.env.linux": {
"PATH": "${env:PATH}"
}
}
}
},
"features": {
"ghcr.io/devcontainers/features/node:1": {}
}
// Features to add to the dev container. More info: https://containers.dev/features.
// "features": {},
// Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [],
// Configure tool-specific properties.
// "customizations": {},
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "root"
}

15
.gitignore vendored Normal file
View file

@ -0,0 +1,15 @@
.prism.log
_dev
__pycache__
.mypy_cache
dist
.venv
.idea
.env
.envrc
codegen.log
Brewfile.lock.json

1
.python-version Normal file
View file

@ -0,0 +1 @@
3.9.18

40
.ruff.toml Normal file
View file

@ -0,0 +1,40 @@
line-length = 100
include = []
exclude = []
force-exclude = true
[lint]
# Same as monorepo but removed UP007
select = [
# https://docs.astral.sh/ruff/rules
# pyflakes, pycodestyle, isort
"B905", # zip-without-explicit-strict
"F",
"E",
"W",
"I001",
"PIE804", # unnecessary-dict-kwargs
"PIE800", # unnecessary-spread
"PIE796", # non-unique-enums
"PIE794", # duplicate-class-field-definition
"PIE807", # reimplemented-container-builtin
"PIE810", #multiple-starts-ends-with
"FLY002",
"COM818",
"SIM",
"Q000",
#"UP007", # Use X | Y for type annotations
# Going to leave these 2 commented out for now as they play interestingly with chz.
# We could consider using them though!
# "TC001", # typing-only-first-party-import
# "TC002", # typing-only-third-party-import
"TC004", # runtime-import-in-type-checking-block
"TC005", # empty-type-checking-block
]
ignore = ["E501", "SIM108", "SIM117"]
unfixable = ["B905"]
extend-safe-fixes = ["TC004", "TC005"]

4
.stats.yml Normal file
View file

@ -0,0 +1,4 @@
configured_endpoints: 15
openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/thinking-machines%2Ftinker-51afac49fce6fe0323489c3e809c5e2e50ce7384828362dcb0b7c2c7c9d77027.yml
openapi_spec_hash: ed32399a1f1754a8e66918aefbe9fd07
config_hash: 2ea282a8a1396267cb7c8a5e28467eb8

4
.sync_state Normal file
View file

@ -0,0 +1,4 @@
{
"last_synced_sha": "a3e70a3c6414c4a3d5bae16cc2a612ab65813dd4",
"last_sync_time": "2025-10-01T17:30:03.308998"
}

3
.vscode/settings.json vendored Normal file
View file

@ -0,0 +1,3 @@
{
"python.analysis.importFormat": "relative",
}

1
Brewfile Normal file
View file

@ -0,0 +1 @@
brew "uv"

128
CONTRIBUTING.md Normal file
View file

@ -0,0 +1,128 @@
## Setting up the environment
### With `uv`
We use [uv](https://docs.astral.sh/uv/) to manage dependencies because it will automatically provision a Python environment with the expected Python version. To set it up, run:
```sh
$ ./scripts/bootstrap
```
Or [install uv manually](https://docs.astral.sh/uv/getting-started/installation/) and run:
```sh
$ uv sync --all-extras
```
You can then run scripts using `uv run python script.py` or by manually activating the virtual environment:
```sh
# manually activate - https://docs.python.org/3/library/venv.html#how-venvs-work
$ source .venv/bin/activate
# now you can omit the `uv run` prefix
$ python script.py
```
### Without `uv`
Alternatively if you don't want to install `uv`, you can stick with the standard `pip` setup by ensuring you have the Python version specified in `.python-version`, create a virtual environment however you desire and then install dependencies using this command:
```sh
$ pip install -r requirements-dev.lock
```
## Modifying/Adding code
Most of the SDK is generated code. Modifications to code will be persisted between generations, but may
result in merge conflicts between manual patches and changes from the generator. The generator will never
modify the contents of the `src/tinker/lib/` and `examples/` directories.
## Adding and running examples
All files in the `examples/` directory are not modified by the generator and can be freely edited or added to.
```py
# add an example to examples/<your-example>.py
#!/usr/bin/env -S uv run python
```
```sh
$ chmod +x examples/<your-example>.py
# run the example against your api
$ ./examples/<your-example>.py
```
## Using the repository from source
If you'd like to use the repository from source, you can either install from git or link to a cloned repository:
To install via git:
```sh
$ pip install git+ssh://git@github.com/stainless-sdks/tinker-python.git
```
Alternatively, you can build from source and install the wheel file:
Building this package will create two files in the `dist/` directory, a `.tar.gz` containing the source files and a `.whl` that can be used to install the package efficiently.
To create a distributable version of the library, all you have to do is run this command:
```sh
$ uv build
# or
$ python -m build
```
Then to install:
```sh
$ pip install ./path-to-wheel-file.whl
```
## Running tests
Most tests require you to [set up a mock server](https://github.com/stoplightio/prism) against the OpenAPI spec to run the tests.
```sh
# you will need npm installed
$ npx prism mock path/to/your/openapi.yml
```
```sh
$ ./scripts/test
```
## Linting and formatting
This repository uses [ruff](https://github.com/astral-sh/ruff) and
[black](https://github.com/psf/black) to format the code in the repository.
To lint:
```sh
$ ./scripts/lint
```
To format and fix all ruff issues automatically:
```sh
$ ./scripts/format
```
## Publishing and releases
Changes made to this repository via the automated release PR pipeline should publish to PyPI automatically. If
the changes aren't made through the automated pipeline, you may want to make releases manually.
### Publish with a GitHub workflow
You can release to package managers by using [the `Publish PyPI` GitHub action](https://www.github.com/stainless-sdks/tinker-python/actions/workflows/publish-pypi.yml). This requires an organization or repository secret to be set up.
### Publish manually
If you need to manually release a package, you can run the `bin/publish-pypi` script with a `PYPI_TOKEN` set on
the environment.

201
LICENSE Normal file
View file

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2025 Tinker
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

7
README.md Normal file
View file

@ -0,0 +1,7 @@
<h1 align="center">Tinker Python SDK</h1>
<div align="center">
<img src="docs/images/logo.svg" width="60%" />
Documentation:
<a href="http://tinker-docs.thinkingmachines.ai/">tinker-docs.thinkingmachines.ai</a>
</div>

23
SECURITY.md Normal file
View file

@ -0,0 +1,23 @@
# Security Policy
## Reporting Security Issues
This SDK is generated by [Stainless Software Inc](http://stainless.com). Stainless takes security seriously, and encourages you to report any security vulnerability promptly so that appropriate action can be taken.
To report a security issue, please contact the Stainless team at security@stainless.com.
## Responsible Disclosure
We appreciate the efforts of security researchers and individuals who help us maintain the security of
SDKs we generate. If you believe you have found a security vulnerability, please adhere to responsible
disclosure practices by allowing us a reasonable amount of time to investigate and address the issue
before making any information public.
## Reporting Non-SDK Related Security Issues
If you encounter security issues that are not directly related to SDKs but pertain to the services
or products provided by Tinker, please follow the respective company's security reporting guidelines.
---
Thank you for helping us keep the SDKs and systems they interact with secure.

7
bin/publish-pypi Normal file
View file

@ -0,0 +1,7 @@
#!/usr/bin/env bash
# Build the package from a clean dist/ and publish it to PyPI.
# Requires PYPI_TOKEN to be set in the environment.
set -eux
rm -rf dist
mkdir -p dist
uv build
# Turn off xtrace before the publish step so the secret token is never
# echoed into CI logs; ${PYPI_TOKEN:?} fails fast with a clear message
# if the variable is missing or empty.
set +x
uv publish --token="${PYPI_TOKEN:?PYPI_TOKEN must be set}"

BIN
docs/images/logo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 319 KiB

967
docs/images/logo.svg Normal file
View file

@ -0,0 +1,967 @@
<svg width="820" height="512" viewBox="0 0 820 512" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect width="820" height="512" fill="white"/>
<rect x="618.852" y="46" width="2" height="48" transform="rotate(-45 618.852 46)" fill="#FFCE52"/>
<rect x="679.797" y="44.5859" width="2" height="48" transform="rotate(45 679.797 44.5859)" fill="#FFCE52"/>
<rect x="233" y="75" width="120" height="2" fill="#73BB83"/>
<rect x="463" y="75" width="120" height="2" fill="#73BB83"/>
<rect x="233" y="195" width="120" height="2" fill="#73BB83"/>
<rect x="463" y="195" width="120" height="2" fill="#73BB83"/>
<rect x="110" y="75" width="120" height="2" fill="#FF9D52"/>
<rect x="140" y="45" width="60" height="2" fill="#FF9D52"/>
<rect x="380" y="45" width="60" height="2" fill="#FF9D52"/>
<rect x="620" y="45" width="60" height="2" fill="#FF9D52"/>
<rect x="110" y="135" width="120" height="2" fill="#FF9D52"/>
<rect x="110" y="195" width="120" height="2" fill="#FF9D52"/>
<rect x="169" y="74" width="2" height="120" fill="#4196D8"/>
<rect x="230.805" y="134" width="2" height="90" transform="rotate(45 230.805 134)" fill="#D34F4F"/>
<rect x="174.234" y="80.4141" width="2" height="80" transform="rotate(-45 174.234 80.4141)" fill="#D34F4F"/>
<rect x="235.234" y="82.4141" width="2" height="160" transform="rotate(-45 235.234 82.4141)" fill="#D34F4F"/>
<rect x="235.234" y="202.414" width="2" height="160" transform="rotate(-45 235.234 202.414)" fill="#D34F4F"/>
<rect x="115.234" y="202.414" width="2" height="160" transform="rotate(-45 115.234 202.414)" fill="#4196D8"/>
<rect x="355.234" y="202.414" width="2" height="160" transform="rotate(-45 355.234 202.414)" fill="#4196D8"/>
<rect x="595.234" y="202.414" width="2" height="160" transform="rotate(-45 595.234 202.414)" fill="#4196D8"/>
<rect x="475.234" y="202.414" width="2" height="160" transform="rotate(-45 475.234 202.414)" fill="#D34F4F"/>
<rect x="473.234" y="80.4141" width="2" height="170" transform="rotate(-45 473.234 80.4141)" fill="#D34F4F"/>
<rect x="109.234" y="137" width="2" height="80" transform="rotate(-45 109.234 137)" fill="#D34F4F"/>
<rect x="167.805" y="75.5859" width="2" height="80" transform="rotate(45 167.805 75.5859)" fill="#D34F4F"/>
<rect x="139.797" y="44.5859" width="2" height="48" transform="rotate(45 139.797 44.5859)" fill="#FFCE52"/>
<rect x="138.852" y="46" width="2" height="48" transform="rotate(-45 138.852 46)" fill="#FFCE52"/>
<rect x="378.852" y="46" width="2" height="48" transform="rotate(-45 378.852 46)" fill="#FFCE52"/>
<rect x="198.844" y="46" width="2" height="48" transform="rotate(-45 198.844 46)" fill="#FFCE52"/>
<rect x="438.844" y="46" width="2" height="48" transform="rotate(-45 438.844 46)" fill="#FFCE52"/>
<rect x="678.844" y="46" width="2" height="48" transform="rotate(-45 678.844 46)" fill="#FFCE52"/>
<rect x="199.797" y="44.5859" width="2" height="48" transform="rotate(45 199.797 44.5859)" fill="#FFCE52"/>
<rect x="379.797" y="44.5859" width="2" height="48" transform="rotate(45 379.797 44.5859)" fill="#FFCE52"/>
<rect x="439.797" y="44.5859" width="2" height="48" transform="rotate(45 439.797 44.5859)" fill="#FFCE52"/>
<rect x="619.797" y="44.5859" width="2" height="48" transform="rotate(45 619.797 44.5859)" fill="#FFCE52"/>
<rect x="348.805" y="75.5859" width="2" height="170" transform="rotate(45 348.805 75.5859)" fill="#D34F4F"/>
<rect x="348.805" y="195.586" width="2" height="170" transform="rotate(45 348.805 195.586)" fill="#D34F4F"/>
<rect x="588.805" y="195.586" width="2" height="170" transform="rotate(45 588.805 195.586)" fill="#D34F4F"/>
<rect x="588.805" y="75.5859" width="2" height="170" transform="rotate(45 588.805 75.5859)" fill="#D34F4F"/>
<rect x="109" y="74" width="2" height="120" fill="#FF9D52"/>
<rect x="229" y="74" width="2" height="120" fill="#FF9D52"/>
<circle cx="110" cy="76" r="7.5" fill="url(#paint0_radial_781_319508)" stroke="white"/>
<circle cx="110" cy="76" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="140" cy="76" r="7.5" fill="url(#paint1_radial_781_319508)" stroke="white"/>
<circle cx="140" cy="76" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="170" cy="76" r="7.5" fill="url(#paint2_radial_781_319508)" stroke="white"/>
<circle cx="170" cy="76" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="200" cy="76" r="7.5" fill="url(#paint3_radial_781_319508)" stroke="white"/>
<circle cx="200" cy="76" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="230" cy="76" r="7.5" fill="url(#paint4_radial_781_319508)" stroke="white"/>
<circle cx="230" cy="76" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="140" cy="136" r="7.5" fill="url(#paint5_radial_781_319508)" stroke="white"/>
<circle cx="140" cy="136" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="200" cy="136" r="7.5" fill="url(#paint6_radial_781_319508)" stroke="white"/>
<circle cx="200" cy="136" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="170" cy="136" r="7.5" fill="url(#paint7_radial_781_319508)" stroke="white"/>
<circle cx="170" cy="136" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="230" cy="106" r="7.5" fill="url(#paint8_radial_781_319508)" stroke="white"/>
<circle cx="230" cy="106" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="110" cy="106" r="7.5" fill="url(#paint9_radial_781_319508)" stroke="white"/>
<circle cx="110" cy="106" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="230" cy="136" r="7.5" fill="url(#paint10_radial_781_319508)" stroke="white"/>
<circle cx="230" cy="136" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="110" cy="136" r="7.5" fill="url(#paint11_radial_781_319508)" stroke="white"/>
<circle cx="110" cy="136" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="230" cy="166" r="7.5" fill="url(#paint12_radial_781_319508)" stroke="white"/>
<circle cx="230" cy="166" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="110" cy="166" r="7.5" fill="url(#paint13_radial_781_319508)" stroke="white"/>
<circle cx="110" cy="166" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="290" cy="136" r="7.5" fill="url(#paint14_radial_781_319508)" stroke="white"/>
<circle cx="290" cy="136" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="290" cy="256" r="7.5" fill="url(#paint15_radial_781_319508)" stroke="white"/>
<circle cx="290" cy="256" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="530" cy="256" r="7.5" fill="url(#paint16_radial_781_319508)" stroke="white"/>
<circle cx="530" cy="256" r="1" fill="black" fill-opacity="0.4"/>
<rect x="350" y="75" width="120" height="2" fill="#FF9D52"/>
<rect x="350" y="135" width="120" height="2" fill="#FF9D52"/>
<rect x="350" y="195" width="120" height="2" fill="#FF9D52"/>
<rect x="409" y="74" width="2" height="120" fill="#4196D8"/>
<rect x="470.805" y="134" width="2" height="80" transform="rotate(45 470.805 134)" fill="#D34F4F"/>
<rect x="414.234" y="81.4141" width="2" height="80" transform="rotate(-45 414.234 81.4141)" fill="#D34F4F"/>
<rect x="349.234" y="137" width="2" height="80" transform="rotate(-45 349.234 137)" fill="#D34F4F"/>
<rect x="408.789" y="75.5859" width="2" height="80" transform="rotate(45 408.789 75.5859)" fill="#D34F4F"/>
<rect x="349" y="74" width="2" height="120" fill="#FF9D52"/>
<rect x="349" y="204" width="2" height="120" fill="#73BB83"/>
<rect x="229" y="204" width="2" height="120" fill="#73BB83"/>
<rect x="109" y="204" width="2" height="120" fill="#73BB83"/>
<rect x="469" y="204" width="2" height="120" fill="#73BB83"/>
<rect x="709" y="204" width="2" height="120" fill="#73BB83"/>
<rect x="589" y="204" width="2" height="120" fill="#73BB83"/>
<rect x="469" y="74" width="2" height="120" fill="#FF9D52"/>
<circle cx="350" cy="76" r="7.5" fill="url(#paint17_radial_781_319508)" stroke="white"/>
<circle cx="350" cy="76" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="380" cy="76" r="7.5" fill="url(#paint18_radial_781_319508)" stroke="white"/>
<circle cx="380" cy="76" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="410" cy="76" r="7.5" fill="url(#paint19_radial_781_319508)" stroke="white"/>
<circle cx="410" cy="76" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="440" cy="76" r="7.5" fill="url(#paint20_radial_781_319508)" stroke="white"/>
<circle cx="440" cy="76" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="470" cy="76" r="7.5" fill="url(#paint21_radial_781_319508)" stroke="white"/>
<circle cx="470" cy="76" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="380" cy="136" r="7.5" fill="url(#paint22_radial_781_319508)" stroke="white"/>
<circle cx="380" cy="136" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="440" cy="136" r="7.5" fill="url(#paint23_radial_781_319508)" stroke="white"/>
<circle cx="440" cy="136" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="410" cy="136" r="7.5" fill="url(#paint24_radial_781_319508)" stroke="white"/>
<circle cx="410" cy="136" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="470" cy="106" r="7.5" fill="url(#paint25_radial_781_319508)" stroke="white"/>
<circle cx="470" cy="106" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="350" cy="106" r="7.5" fill="url(#paint26_radial_781_319508)" stroke="white"/>
<circle cx="350" cy="106" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="470" cy="136" r="7.5" fill="url(#paint27_radial_781_319508)" stroke="white"/>
<circle cx="470" cy="136" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="350" cy="136" r="7.5" fill="url(#paint28_radial_781_319508)" stroke="white"/>
<circle cx="350" cy="136" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="470" cy="166" r="7.5" fill="url(#paint29_radial_781_319508)" stroke="white"/>
<circle cx="470" cy="166" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="350" cy="166" r="7.5" fill="url(#paint30_radial_781_319508)" stroke="white"/>
<circle cx="350" cy="166" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="350" cy="196" r="7.5" fill="url(#paint31_radial_781_319508)" stroke="white"/>
<circle cx="350" cy="196" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="380" cy="196" r="7.5" fill="url(#paint32_radial_781_319508)" stroke="white"/>
<circle cx="380" cy="196" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="410" cy="196" r="7.5" fill="url(#paint33_radial_781_319508)" stroke="white"/>
<circle cx="410" cy="196" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="440" cy="196" r="7.5" fill="url(#paint34_radial_781_319508)" stroke="white"/>
<circle cx="440" cy="196" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="470" cy="196" r="7.5" fill="url(#paint35_radial_781_319508)" stroke="white"/>
<circle cx="470" cy="196" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="530" cy="136" r="7.5" fill="url(#paint36_radial_781_319508)" stroke="white"/>
<circle cx="530" cy="136" r="1" fill="black" fill-opacity="0.4"/>
<rect x="590" y="75" width="120" height="2" fill="#FF9D52"/>
<rect x="590" y="135" width="120" height="2" fill="#FF9D52"/>
<rect x="590" y="195" width="120" height="2" fill="#FF9D52"/>
<rect x="649" y="74" width="2" height="120" fill="#4196D8"/>
<rect x="710.805" y="134" width="2" height="80" transform="rotate(45 710.805 134)" fill="#D34F4F"/>
<rect x="654.234" y="81.4141" width="2" height="80" transform="rotate(-45 654.234 81.4141)" fill="#D34F4F"/>
<rect x="589.234" y="137" width="2" height="80" transform="rotate(-45 589.234 137)" fill="#D34F4F"/>
<rect x="646.086" y="78.5859" width="2" height="80" transform="rotate(45 646.086 78.5859)" fill="#D34F4F"/>
<rect x="589" y="74" width="2" height="120" fill="#FF9D52"/>
<rect x="709" y="74" width="2" height="120" fill="#FF9D52"/>
<circle cx="590" cy="76" r="7.5" fill="url(#paint37_radial_781_319508)" stroke="white"/>
<circle cx="590" cy="76" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="620" cy="76" r="7.5" fill="url(#paint38_radial_781_319508)" stroke="white"/>
<circle cx="620" cy="76" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="650" cy="76" r="7.5" fill="url(#paint39_radial_781_319508)" stroke="white"/>
<circle cx="650" cy="76" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="680" cy="76" r="7.5" fill="url(#paint40_radial_781_319508)" stroke="white"/>
<circle cx="680" cy="76" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="710" cy="76" r="7.5" fill="url(#paint41_radial_781_319508)" stroke="white"/>
<circle cx="710" cy="76" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="620" cy="136" r="7.5" fill="url(#paint42_radial_781_319508)" stroke="white"/>
<circle cx="620" cy="136" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="680" cy="136" r="7.5" fill="url(#paint43_radial_781_319508)" stroke="white"/>
<circle cx="680" cy="136" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="650" cy="136" r="7.5" fill="url(#paint44_radial_781_319508)" stroke="white"/>
<circle cx="650" cy="136" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="710" cy="106" r="7.5" fill="url(#paint45_radial_781_319508)" stroke="white"/>
<circle cx="710" cy="106" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="590" cy="106" r="7.5" fill="url(#paint46_radial_781_319508)" stroke="white"/>
<circle cx="590" cy="106" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="710" cy="136" r="7.5" fill="url(#paint47_radial_781_319508)" stroke="white"/>
<circle cx="710" cy="136" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="590" cy="136" r="7.5" fill="url(#paint48_radial_781_319508)" stroke="white"/>
<circle cx="590" cy="136" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="710" cy="166" r="7.5" fill="url(#paint49_radial_781_319508)" stroke="white"/>
<circle cx="710" cy="166" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="590" cy="166" r="7.5" fill="url(#paint50_radial_781_319508)" stroke="white"/>
<circle cx="590" cy="166" r="1" fill="black" fill-opacity="0.4"/>
<rect x="59" y="134" width="2" height="250" fill="#FF9D52"/>
<rect x="113.195" y="199.129" width="2" height="80" transform="rotate(140.469 113.195 199.129)" fill="#FFCE52"/>
<rect x="59.5391" y="137.977" width="2" height="80" transform="rotate(-140.47 59.5391 137.977)" fill="#FFCE52"/>
<rect x="60.5391" y="256.977" width="2" height="80" transform="rotate(-140.47 60.5391 256.977)" fill="#FFCE52"/>
<rect x="61.5391" y="375.977" width="2" height="80" transform="rotate(-140.47 61.5391 375.977)" fill="#FFCE52"/>
<rect x="113.195" y="439.129" width="2" height="80" transform="rotate(140.469 113.195 439.129)" fill="#FFCE52"/>
<rect x="113.195" y="319.129" width="2" height="80" transform="rotate(140.469 113.195 319.129)" fill="#FFCE52"/>
<circle cx="60" cy="136" r="7.5" fill="url(#paint51_radial_781_319508)" stroke="white"/>
<circle cx="60" cy="136" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="60" cy="376" r="7.5" fill="url(#paint52_radial_781_319508)" stroke="white"/>
<circle cx="60" cy="376" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="60" cy="256" r="7.5" fill="url(#paint53_radial_781_319508)" stroke="white"/>
<circle cx="60" cy="256" r="1" fill="black" fill-opacity="0.4"/>
<rect x="761.461" y="367.383" width="2" height="230" transform="rotate(-180 761.461 367.383)" fill="#FF9D52"/>
<rect x="758.922" y="375.402" width="2" height="80" transform="rotate(39.53 758.922 375.402)" fill="#FFCE52"/>
<rect x="757.922" y="256.402" width="2" height="80" transform="rotate(39.53 757.922 256.402)" fill="#FFCE52"/>
<rect x="756.922" y="137.406" width="2" height="80" transform="rotate(39.53 756.922 137.406)" fill="#FFCE52"/>
<rect x="707.266" y="192.25" width="2" height="80" transform="rotate(-39.5305 707.266 192.25)" fill="#FFCE52"/>
<circle cx="110" cy="196" r="7.5" fill="url(#paint54_radial_781_319508)" stroke="white"/>
<circle cx="110" cy="196" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="140" cy="196" r="7.5" fill="url(#paint55_radial_781_319508)" stroke="white"/>
<circle cx="140" cy="196" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="170" cy="196" r="7.5" fill="url(#paint56_radial_781_319508)" stroke="white"/>
<circle cx="170" cy="196" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="200" cy="196" r="7.5" fill="url(#paint57_radial_781_319508)" stroke="white"/>
<circle cx="200" cy="196" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="230" cy="196" r="7.5" fill="url(#paint58_radial_781_319508)" stroke="white"/>
<circle cx="230" cy="196" r="1" fill="black" fill-opacity="0.4"/>
<rect x="707.266" y="312.25" width="2" height="80" transform="rotate(-39.5305 707.266 312.25)" fill="#FFCE52"/>
<rect x="707.266" y="72.2539" width="2" height="80" transform="rotate(-39.5305 707.266 72.2539)" fill="#FFCE52"/>
<rect x="707.266" y="72.2539" width="2" height="80" transform="rotate(-39.5305 707.266 72.2539)" fill="#FFCE52"/>
<circle cx="760.461" cy="375.379" r="7.5" transform="rotate(-180 760.461 375.379)" fill="url(#paint59_radial_781_319508)" stroke="white"/>
<circle cx="760.461" cy="375.379" r="1" transform="rotate(-180 760.461 375.379)" fill="black" fill-opacity="0.4"/>
<circle cx="760.461" cy="255.379" r="7.5" transform="rotate(-180 760.461 255.379)" fill="url(#paint60_radial_781_319508)" stroke="white"/>
<circle cx="760.461" cy="255.379" r="1" transform="rotate(-180 760.461 255.379)" fill="black" fill-opacity="0.4"/>
<circle cx="760.461" cy="135.383" r="7.5" transform="rotate(-180 760.461 135.383)" fill="url(#paint61_radial_781_319508)" stroke="white"/>
<circle cx="760.461" cy="135.383" r="1" transform="rotate(-180 760.461 135.383)" fill="black" fill-opacity="0.4"/>
<circle cx="140" cy="46" r="7.5" fill="url(#paint62_radial_781_319508)" stroke="white"/>
<circle cx="140" cy="46" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="380" cy="46" r="7.5" fill="url(#paint63_radial_781_319508)" stroke="white"/>
<circle cx="380" cy="46" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="620" cy="46" r="7.5" fill="url(#paint64_radial_781_319508)" stroke="white"/>
<circle cx="620" cy="46" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="200" cy="46" r="7.5" fill="url(#paint65_radial_781_319508)" stroke="white"/>
<circle cx="200" cy="46" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="440" cy="46" r="7.5" fill="url(#paint66_radial_781_319508)" stroke="white"/>
<circle cx="440" cy="46" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="680" cy="46" r="7.5" fill="url(#paint67_radial_781_319508)" stroke="white"/>
<circle cx="680" cy="46" r="1" fill="black" fill-opacity="0.4"/>
<rect x="201.148" y="466" width="2" height="48" transform="rotate(135 201.148 466)" fill="#FFCE52"/>
<rect x="140.203" y="467.414" width="2" height="48" transform="rotate(-135 140.203 467.414)" fill="#FFCE52"/>
<rect x="587" y="437" width="120" height="2" transform="rotate(-180 587 437)" fill="#73BB83"/>
<rect x="357" y="437" width="120" height="2" transform="rotate(-180 357 437)" fill="#73BB83"/>
<rect x="587" y="317" width="120" height="2" transform="rotate(-180 587 317)" fill="#73BB83"/>
<rect x="357" y="317" width="120" height="2" transform="rotate(-180 357 317)" fill="#73BB83"/>
<rect x="710" y="437" width="120" height="2" transform="rotate(-180 710 437)" fill="#FF9D52"/>
<rect x="680" y="467" width="60" height="2" transform="rotate(-180 680 467)" fill="#FF9D52"/>
<rect x="440" y="467" width="60" height="2" transform="rotate(-180 440 467)" fill="#FF9D52"/>
<rect x="200" y="467" width="60" height="2" transform="rotate(-180 200 467)" fill="#FF9D52"/>
<rect x="710" y="377" width="120" height="2" transform="rotate(-180 710 377)" fill="#FF9D52"/>
<rect x="710" y="317" width="120" height="2" transform="rotate(-180 710 317)" fill="#FF9D52"/>
<rect x="651" y="438" width="2" height="120" transform="rotate(-180 651 438)" fill="#4196D8"/>
<rect x="589.195" y="378" width="2" height="90" transform="rotate(-135 589.195 378)" fill="#D34F4F"/>
<rect x="645.766" y="431.586" width="2" height="80" transform="rotate(135 645.766 431.586)" fill="#D34F4F"/>
<rect x="584.766" y="429.586" width="2" height="160" transform="rotate(135 584.766 429.586)" fill="#D34F4F"/>
<rect x="346.766" y="431.586" width="2" height="170" transform="rotate(135 346.766 431.586)" fill="#D34F4F"/>
<rect x="710.766" y="375" width="2" height="80" transform="rotate(135 710.766 375)" fill="#D34F4F"/>
<rect x="652.195" y="436.414" width="2" height="80" transform="rotate(-135 652.195 436.414)" fill="#D34F4F"/>
<rect x="680.203" y="467.414" width="2" height="48" transform="rotate(-135 680.203 467.414)" fill="#FFCE52"/>
<rect x="681.148" y="466" width="2" height="48" transform="rotate(135 681.148 466)" fill="#FFCE52"/>
<rect x="441.148" y="466" width="2" height="48" transform="rotate(135 441.148 466)" fill="#FFCE52"/>
<rect x="621.156" y="466" width="2" height="48" transform="rotate(135 621.156 466)" fill="#FFCE52"/>
<rect x="381.156" y="466" width="2" height="48" transform="rotate(135 381.156 466)" fill="#FFCE52"/>
<rect x="141.156" y="466" width="2" height="48" transform="rotate(135 141.156 466)" fill="#FFCE52"/>
<rect x="620.203" y="467.414" width="2" height="48" transform="rotate(-135 620.203 467.414)" fill="#FFCE52"/>
<rect x="440.203" y="467.414" width="2" height="48" transform="rotate(-135 440.203 467.414)" fill="#FFCE52"/>
<rect x="380.203" y="467.414" width="2" height="48" transform="rotate(-135 380.203 467.414)" fill="#FFCE52"/>
<rect x="200.203" y="467.414" width="2" height="48" transform="rotate(-135 200.203 467.414)" fill="#FFCE52"/>
<rect x="471.195" y="436.414" width="2" height="170" transform="rotate(-135 471.195 436.414)" fill="#D34F4F"/>
<rect x="231.195" y="436.414" width="2" height="170" transform="rotate(-135 231.195 436.414)" fill="#D34F4F"/>
<rect x="711" y="438" width="2" height="120" transform="rotate(-180 711 438)" fill="#FF9D52"/>
<rect x="591" y="438" width="2" height="120" transform="rotate(-180 591 438)" fill="#FF9D52"/>
<circle cx="710" cy="436" r="7.5" transform="rotate(-180 710 436)" fill="url(#paint68_radial_781_319508)" stroke="white"/>
<circle cx="710" cy="436" r="1" transform="rotate(-180 710 436)" fill="black" fill-opacity="0.4"/>
<circle cx="680" cy="436" r="7.5" transform="rotate(-180 680 436)" fill="url(#paint69_radial_781_319508)" stroke="white"/>
<circle cx="680" cy="436" r="1" transform="rotate(-180 680 436)" fill="black" fill-opacity="0.4"/>
<circle cx="650" cy="436" r="7.5" transform="rotate(-180 650 436)" fill="url(#paint70_radial_781_319508)" stroke="white"/>
<circle cx="650" cy="436" r="1" transform="rotate(-180 650 436)" fill="black" fill-opacity="0.4"/>
<circle cx="620" cy="436" r="7.5" transform="rotate(-180 620 436)" fill="url(#paint71_radial_781_319508)" stroke="white"/>
<circle cx="620" cy="436" r="1" transform="rotate(-180 620 436)" fill="black" fill-opacity="0.4"/>
<circle cx="590" cy="436" r="7.5" transform="rotate(-180 590 436)" fill="url(#paint72_radial_781_319508)" stroke="white"/>
<circle cx="590" cy="436" r="1" transform="rotate(-180 590 436)" fill="black" fill-opacity="0.4"/>
<circle cx="680" cy="376" r="7.5" transform="rotate(-180 680 376)" fill="url(#paint73_radial_781_319508)" stroke="white"/>
<circle cx="680" cy="376" r="1" transform="rotate(-180 680 376)" fill="black" fill-opacity="0.4"/>
<circle cx="620" cy="376" r="7.5" transform="rotate(-180 620 376)" fill="url(#paint74_radial_781_319508)" stroke="white"/>
<circle cx="620" cy="376" r="1" transform="rotate(-180 620 376)" fill="black" fill-opacity="0.4"/>
<circle cx="650" cy="376" r="7.5" transform="rotate(-180 650 376)" fill="url(#paint75_radial_781_319508)" stroke="white"/>
<circle cx="650" cy="376" r="1" transform="rotate(-180 650 376)" fill="black" fill-opacity="0.4"/>
<circle cx="590" cy="406" r="7.5" transform="rotate(-180 590 406)" fill="url(#paint76_radial_781_319508)" stroke="white"/>
<circle cx="590" cy="406" r="1" transform="rotate(-180 590 406)" fill="black" fill-opacity="0.4"/>
<circle cx="710" cy="406" r="7.5" transform="rotate(-180 710 406)" fill="url(#paint77_radial_781_319508)" stroke="white"/>
<circle cx="710" cy="406" r="1" transform="rotate(-180 710 406)" fill="black" fill-opacity="0.4"/>
<circle cx="590" cy="376" r="7.5" transform="rotate(-180 590 376)" fill="url(#paint78_radial_781_319508)" stroke="white"/>
<circle cx="590" cy="376" r="1" transform="rotate(-180 590 376)" fill="black" fill-opacity="0.4"/>
<circle cx="710" cy="376" r="7.5" transform="rotate(-180 710 376)" fill="url(#paint79_radial_781_319508)" stroke="white"/>
<circle cx="710" cy="376" r="1" transform="rotate(-180 710 376)" fill="black" fill-opacity="0.4"/>
<circle cx="590" cy="346" r="7.5" transform="rotate(-180 590 346)" fill="url(#paint80_radial_781_319508)" stroke="white"/>
<circle cx="590" cy="346" r="1" transform="rotate(-180 590 346)" fill="black" fill-opacity="0.4"/>
<circle cx="710" cy="346" r="7.5" transform="rotate(-180 710 346)" fill="url(#paint81_radial_781_319508)" stroke="white"/>
<circle cx="710" cy="346" r="1" transform="rotate(-180 710 346)" fill="black" fill-opacity="0.4"/>
<circle cx="710" cy="316" r="7.5" transform="rotate(-180 710 316)" fill="url(#paint82_radial_781_319508)" stroke="white"/>
<circle cx="710" cy="316" r="1" transform="rotate(-180 710 316)" fill="black" fill-opacity="0.4"/>
<circle cx="680" cy="316" r="7.5" transform="rotate(-180 680 316)" fill="url(#paint83_radial_781_319508)" stroke="white"/>
<circle cx="680" cy="316" r="1" transform="rotate(-180 680 316)" fill="black" fill-opacity="0.4"/>
<circle cx="650" cy="316" r="7.5" transform="rotate(-180 650 316)" fill="url(#paint84_radial_781_319508)" stroke="white"/>
<circle cx="650" cy="316" r="1" transform="rotate(-180 650 316)" fill="black" fill-opacity="0.4"/>
<circle cx="620" cy="316" r="7.5" transform="rotate(-180 620 316)" fill="url(#paint85_radial_781_319508)" stroke="white"/>
<circle cx="620" cy="316" r="1" transform="rotate(-180 620 316)" fill="black" fill-opacity="0.4"/>
<circle cx="590" cy="316" r="7.5" transform="rotate(-180 590 316)" fill="url(#paint86_radial_781_319508)" stroke="white"/>
<circle cx="590" cy="316" r="1" transform="rotate(-180 590 316)" fill="black" fill-opacity="0.4"/>
<circle cx="530" cy="376" r="7.5" transform="rotate(-180 530 376)" fill="url(#paint87_radial_781_319508)" stroke="white"/>
<circle cx="530" cy="376" r="1" transform="rotate(-180 530 376)" fill="black" fill-opacity="0.4"/>
<rect x="470" y="437" width="120" height="2" transform="rotate(-180 470 437)" fill="#FF9D52"/>
<rect x="470" y="377" width="120" height="2" transform="rotate(-180 470 377)" fill="#FF9D52"/>
<rect x="470" y="317" width="120" height="2" transform="rotate(-180 470 317)" fill="#FF9D52"/>
<rect x="411" y="438" width="2" height="120" transform="rotate(-180 411 438)" fill="#4196D8"/>
<rect x="349.195" y="378" width="2" height="80" transform="rotate(-135 349.195 378)" fill="#D34F4F"/>
<rect x="405.766" y="430.586" width="2" height="80" transform="rotate(135 405.766 430.586)" fill="#D34F4F"/>
<rect x="470.766" y="375" width="2" height="80" transform="rotate(135 470.766 375)" fill="#D34F4F"/>
<rect x="411.211" y="436.414" width="2" height="80" transform="rotate(-135 411.211 436.414)" fill="#D34F4F"/>
<rect x="471" y="438" width="2" height="120" transform="rotate(-180 471 438)" fill="#FF9D52"/>
<rect x="351" y="438" width="2" height="120" transform="rotate(-180 351 438)" fill="#FF9D52"/>
<circle cx="470" cy="436" r="7.5" transform="rotate(-180 470 436)" fill="url(#paint88_radial_781_319508)" stroke="white"/>
<circle cx="470" cy="436" r="1" transform="rotate(-180 470 436)" fill="black" fill-opacity="0.4"/>
<circle cx="440" cy="436" r="7.5" transform="rotate(-180 440 436)" fill="url(#paint89_radial_781_319508)" stroke="white"/>
<circle cx="440" cy="436" r="1" transform="rotate(-180 440 436)" fill="black" fill-opacity="0.4"/>
<circle cx="410" cy="436" r="7.5" transform="rotate(-180 410 436)" fill="url(#paint90_radial_781_319508)" stroke="white"/>
<circle cx="410" cy="436" r="1" transform="rotate(-180 410 436)" fill="black" fill-opacity="0.4"/>
<circle cx="380" cy="436" r="7.5" transform="rotate(-180 380 436)" fill="url(#paint91_radial_781_319508)" stroke="white"/>
<circle cx="380" cy="436" r="1" transform="rotate(-180 380 436)" fill="black" fill-opacity="0.4"/>
<circle cx="350" cy="436" r="7.5" transform="rotate(-180 350 436)" fill="url(#paint92_radial_781_319508)" stroke="white"/>
<circle cx="350" cy="436" r="1" transform="rotate(-180 350 436)" fill="black" fill-opacity="0.4"/>
<circle cx="440" cy="376" r="7.5" transform="rotate(-180 440 376)" fill="url(#paint93_radial_781_319508)" stroke="white"/>
<circle cx="440" cy="376" r="1" transform="rotate(-180 440 376)" fill="black" fill-opacity="0.4"/>
<circle cx="380" cy="376" r="7.5" transform="rotate(-180 380 376)" fill="url(#paint94_radial_781_319508)" stroke="white"/>
<circle cx="380" cy="376" r="1" transform="rotate(-180 380 376)" fill="black" fill-opacity="0.4"/>
<circle cx="410" cy="376" r="7.5" transform="rotate(-180 410 376)" fill="url(#paint95_radial_781_319508)" stroke="white"/>
<circle cx="410" cy="376" r="1" transform="rotate(-180 410 376)" fill="black" fill-opacity="0.4"/>
<circle cx="350" cy="406" r="7.5" transform="rotate(-180 350 406)" fill="url(#paint96_radial_781_319508)" stroke="white"/>
<circle cx="350" cy="406" r="1" transform="rotate(-180 350 406)" fill="black" fill-opacity="0.4"/>
<circle cx="470" cy="406" r="7.5" transform="rotate(-180 470 406)" fill="url(#paint97_radial_781_319508)" stroke="white"/>
<circle cx="470" cy="406" r="1" transform="rotate(-180 470 406)" fill="black" fill-opacity="0.4"/>
<circle cx="350" cy="376" r="7.5" transform="rotate(-180 350 376)" fill="url(#paint98_radial_781_319508)" stroke="white"/>
<circle cx="350" cy="376" r="1" transform="rotate(-180 350 376)" fill="black" fill-opacity="0.4"/>
<circle cx="470" cy="376" r="7.5" transform="rotate(-180 470 376)" fill="url(#paint99_radial_781_319508)" stroke="white"/>
<circle cx="470" cy="376" r="1" transform="rotate(-180 470 376)" fill="black" fill-opacity="0.4"/>
<circle cx="350" cy="346" r="7.5" transform="rotate(-180 350 346)" fill="url(#paint100_radial_781_319508)" stroke="white"/>
<circle cx="350" cy="346" r="1" transform="rotate(-180 350 346)" fill="black" fill-opacity="0.4"/>
<circle cx="470" cy="346" r="7.5" transform="rotate(-180 470 346)" fill="url(#paint101_radial_781_319508)" stroke="white"/>
<circle cx="470" cy="346" r="1" transform="rotate(-180 470 346)" fill="black" fill-opacity="0.4"/>
<circle cx="470" cy="316" r="7.5" transform="rotate(-180 470 316)" fill="url(#paint102_radial_781_319508)" stroke="white"/>
<circle cx="470" cy="316" r="1" transform="rotate(-180 470 316)" fill="black" fill-opacity="0.4"/>
<circle cx="440" cy="316" r="7.5" transform="rotate(-180 440 316)" fill="url(#paint103_radial_781_319508)" stroke="white"/>
<circle cx="440" cy="316" r="1" transform="rotate(-180 440 316)" fill="black" fill-opacity="0.4"/>
<circle cx="410" cy="316" r="7.5" transform="rotate(-180 410 316)" fill="url(#paint104_radial_781_319508)" stroke="white"/>
<circle cx="410" cy="316" r="1" transform="rotate(-180 410 316)" fill="black" fill-opacity="0.4"/>
<circle cx="380" cy="316" r="7.5" transform="rotate(-180 380 316)" fill="url(#paint105_radial_781_319508)" stroke="white"/>
<circle cx="380" cy="316" r="1" transform="rotate(-180 380 316)" fill="black" fill-opacity="0.4"/>
<circle cx="350" cy="316" r="7.5" transform="rotate(-180 350 316)" fill="url(#paint106_radial_781_319508)" stroke="white"/>
<circle cx="350" cy="316" r="1" transform="rotate(-180 350 316)" fill="black" fill-opacity="0.4"/>
<circle cx="290" cy="376" r="7.5" transform="rotate(-180 290 376)" fill="url(#paint107_radial_781_319508)" stroke="white"/>
<circle cx="290" cy="376" r="1" transform="rotate(-180 290 376)" fill="black" fill-opacity="0.4"/>
<rect x="230" y="437" width="120" height="2" transform="rotate(-180 230 437)" fill="#FF9D52"/>
<rect x="230" y="377" width="120" height="2" transform="rotate(-180 230 377)" fill="#FF9D52"/>
<rect x="230" y="317" width="120" height="2" transform="rotate(-180 230 317)" fill="#FF9D52"/>
<rect x="171" y="438" width="2" height="120" transform="rotate(-180 171 438)" fill="#4196D8"/>
<rect x="109.195" y="378" width="2" height="80" transform="rotate(-135 109.195 378)" fill="#D34F4F"/>
<rect x="165.766" y="430.586" width="2" height="80" transform="rotate(135 165.766 430.586)" fill="#D34F4F"/>
<rect x="230.766" y="375" width="2" height="80" transform="rotate(135 230.766 375)" fill="#D34F4F"/>
<rect x="173.914" y="433.414" width="2" height="80" transform="rotate(-135 173.914 433.414)" fill="#D34F4F"/>
<rect x="231" y="438" width="2" height="120" transform="rotate(-180 231 438)" fill="#FF9D52"/>
<rect x="111" y="438" width="2" height="120" transform="rotate(-180 111 438)" fill="#FF9D52"/>
<circle cx="230" cy="436" r="7.5" transform="rotate(-180 230 436)" fill="url(#paint108_radial_781_319508)" stroke="white"/>
<circle cx="230" cy="436" r="1" transform="rotate(-180 230 436)" fill="black" fill-opacity="0.4"/>
<circle cx="200" cy="436" r="7.5" transform="rotate(-180 200 436)" fill="url(#paint109_radial_781_319508)" stroke="white"/>
<circle cx="200" cy="436" r="1" transform="rotate(-180 200 436)" fill="black" fill-opacity="0.4"/>
<circle cx="170" cy="436" r="7.5" transform="rotate(-180 170 436)" fill="url(#paint110_radial_781_319508)" stroke="white"/>
<circle cx="170" cy="436" r="1" transform="rotate(-180 170 436)" fill="black" fill-opacity="0.4"/>
<circle cx="140" cy="436" r="7.5" transform="rotate(-180 140 436)" fill="url(#paint111_radial_781_319508)" stroke="white"/>
<circle cx="140" cy="436" r="1" transform="rotate(-180 140 436)" fill="black" fill-opacity="0.4"/>
<circle cx="110" cy="436" r="7.5" transform="rotate(-180 110 436)" fill="url(#paint112_radial_781_319508)" stroke="white"/>
<circle cx="110" cy="436" r="1" transform="rotate(-180 110 436)" fill="black" fill-opacity="0.4"/>
<circle cx="200" cy="376" r="7.5" transform="rotate(-180 200 376)" fill="url(#paint113_radial_781_319508)" stroke="white"/>
<circle cx="200" cy="376" r="1" transform="rotate(-180 200 376)" fill="black" fill-opacity="0.4"/>
<circle cx="140" cy="376" r="7.5" transform="rotate(-180 140 376)" fill="url(#paint114_radial_781_319508)" stroke="white"/>
<circle cx="140" cy="376" r="1" transform="rotate(-180 140 376)" fill="black" fill-opacity="0.4"/>
<circle cx="170" cy="376" r="7.5" transform="rotate(-180 170 376)" fill="url(#paint115_radial_781_319508)" stroke="white"/>
<circle cx="170" cy="376" r="1" transform="rotate(-180 170 376)" fill="black" fill-opacity="0.4"/>
<circle cx="110" cy="406" r="7.5" transform="rotate(-180 110 406)" fill="url(#paint116_radial_781_319508)" stroke="white"/>
<circle cx="110" cy="406" r="1" transform="rotate(-180 110 406)" fill="black" fill-opacity="0.4"/>
<circle cx="230" cy="406" r="7.5" transform="rotate(-180 230 406)" fill="url(#paint117_radial_781_319508)" stroke="white"/>
<circle cx="230" cy="406" r="1" transform="rotate(-180 230 406)" fill="black" fill-opacity="0.4"/>
<circle cx="110" cy="376" r="7.5" transform="rotate(-180 110 376)" fill="url(#paint118_radial_781_319508)" stroke="white"/>
<circle cx="110" cy="376" r="1" transform="rotate(-180 110 376)" fill="black" fill-opacity="0.4"/>
<circle cx="230" cy="376" r="7.5" transform="rotate(-180 230 376)" fill="url(#paint119_radial_781_319508)" stroke="white"/>
<circle cx="230" cy="376" r="1" transform="rotate(-180 230 376)" fill="black" fill-opacity="0.4"/>
<circle cx="110" cy="346" r="7.5" transform="rotate(-180 110 346)" fill="url(#paint120_radial_781_319508)" stroke="white"/>
<circle cx="110" cy="346" r="1" transform="rotate(-180 110 346)" fill="black" fill-opacity="0.4"/>
<circle cx="230" cy="346" r="7.5" transform="rotate(-180 230 346)" fill="url(#paint121_radial_781_319508)" stroke="white"/>
<circle cx="230" cy="346" r="1" transform="rotate(-180 230 346)" fill="black" fill-opacity="0.4"/>
<circle cx="230" cy="316" r="7.5" transform="rotate(-180 230 316)" fill="url(#paint122_radial_781_319508)" stroke="white"/>
<circle cx="230" cy="316" r="1" transform="rotate(-180 230 316)" fill="black" fill-opacity="0.4"/>
<circle cx="200" cy="316" r="7.5" transform="rotate(-180 200 316)" fill="url(#paint123_radial_781_319508)" stroke="white"/>
<circle cx="200" cy="316" r="1" transform="rotate(-180 200 316)" fill="black" fill-opacity="0.4"/>
<circle cx="170" cy="316" r="7.5" transform="rotate(-180 170 316)" fill="url(#paint124_radial_781_319508)" stroke="white"/>
<circle cx="170" cy="316" r="1" transform="rotate(-180 170 316)" fill="black" fill-opacity="0.4"/>
<circle cx="140" cy="316" r="7.5" transform="rotate(-180 140 316)" fill="url(#paint125_radial_781_319508)" stroke="white"/>
<circle cx="140" cy="316" r="1" transform="rotate(-180 140 316)" fill="black" fill-opacity="0.4"/>
<circle cx="110" cy="316" r="7.5" transform="rotate(-180 110 316)" fill="url(#paint126_radial_781_319508)" stroke="white"/>
<circle cx="110" cy="316" r="1" transform="rotate(-180 110 316)" fill="black" fill-opacity="0.4"/>
<circle cx="680" cy="466" r="7.5" transform="rotate(-180 680 466)" fill="url(#paint127_radial_781_319508)" stroke="white"/>
<circle cx="680" cy="466" r="1" transform="rotate(-180 680 466)" fill="black" fill-opacity="0.4"/>
<circle cx="440" cy="466" r="7.5" transform="rotate(-180 440 466)" fill="url(#paint128_radial_781_319508)" stroke="white"/>
<circle cx="440" cy="466" r="1" transform="rotate(-180 440 466)" fill="black" fill-opacity="0.4"/>
<circle cx="200" cy="466" r="7.5" transform="rotate(-180 200 466)" fill="url(#paint129_radial_781_319508)" stroke="white"/>
<circle cx="200" cy="466" r="1" transform="rotate(-180 200 466)" fill="black" fill-opacity="0.4"/>
<circle cx="620" cy="466" r="7.5" transform="rotate(-180 620 466)" fill="url(#paint130_radial_781_319508)" stroke="white"/>
<circle cx="620" cy="466" r="1" transform="rotate(-180 620 466)" fill="black" fill-opacity="0.4"/>
<circle cx="380" cy="466" r="7.5" transform="rotate(-180 380 466)" fill="url(#paint131_radial_781_319508)" stroke="white"/>
<circle cx="380" cy="466" r="1" transform="rotate(-180 380 466)" fill="black" fill-opacity="0.4"/>
<circle cx="140" cy="466" r="7.5" transform="rotate(-180 140 466)" fill="url(#paint132_radial_781_319508)" stroke="white"/>
<circle cx="140" cy="466" r="1" transform="rotate(-180 140 466)" fill="black" fill-opacity="0.4"/>
<circle cx="590" cy="196" r="7.5" fill="url(#paint133_radial_781_319508)" stroke="white"/>
<circle cx="590" cy="196" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="620" cy="196" r="7.5" fill="url(#paint134_radial_781_319508)" stroke="white"/>
<circle cx="620" cy="196" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="650" cy="196" r="7.5" fill="url(#paint135_radial_781_319508)" stroke="white"/>
<circle cx="650" cy="196" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="680" cy="196" r="7.5" fill="url(#paint136_radial_781_319508)" stroke="white"/>
<circle cx="680" cy="196" r="1" fill="black" fill-opacity="0.4"/>
<circle cx="710" cy="196" r="7.5" fill="url(#paint137_radial_781_319508)" stroke="white"/>
<circle cx="710" cy="196" r="1" fill="black" fill-opacity="0.4"/>
<defs>
<radialGradient id="paint0_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(110 76) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint1_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(140 76) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint2_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(170 76) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint3_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(200 76) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint4_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(230 76) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint5_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(140 136) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint6_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(200 136) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint7_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(170 136) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint8_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(230 106) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint9_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(110 106) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint10_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(230 136) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint11_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(110 136) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint12_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(230 166) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint13_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(110 166) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint14_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(290 136) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint15_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(290 256) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint16_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(530 256) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint17_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(350 76) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint18_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(380 76) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint19_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(410 76) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint20_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(440 76) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint21_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(470 76) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint22_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(380 136) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint23_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(440 136) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint24_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(410 136) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint25_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(470 106) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint26_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(350 106) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint27_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(470 136) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint28_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(350 136) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint29_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(470 166) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint30_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(350 166) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint31_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(350 196) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint32_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(380 196) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint33_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(410 196) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint34_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(440 196) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint35_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(470 196) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint36_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(530 136) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint37_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(590 76) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint38_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(620 76) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint39_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(650 76) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint40_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(680 76) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint41_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(710 76) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint42_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(620 136) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint43_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(680 136) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint44_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(650 136) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint45_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(710 106) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint46_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(590 106) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint47_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(710 136) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint48_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(590 136) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint49_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(710 166) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint50_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(590 166) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint51_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(60 136) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint52_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(60 376) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint53_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(60 256) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint54_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(110 196) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint55_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(140 196) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint56_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(170 196) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint57_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(200 196) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint58_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(230 196) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint59_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(760.461 375.379) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint60_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(760.461 255.379) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint61_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(760.461 135.383) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint62_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(140 46) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint63_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(380 46) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint64_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(620 46) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint65_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(200 46) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint66_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(440 46) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint67_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(680 46) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint68_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(710 436) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint69_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(680 436) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint70_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(650 436) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint71_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(620 436) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint72_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(590 436) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint73_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(680 376) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint74_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(620 376) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint75_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(650 376) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint76_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(590 406) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint77_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(710 406) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint78_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(590 376) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint79_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(710 376) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint80_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(590 346) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint81_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(710 346) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint82_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(710 316) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint83_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(680 316) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint84_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(650 316) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint85_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(620 316) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint86_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(590 316) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint87_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(530 376) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint88_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(470 436) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint89_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(440 436) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint90_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(410 436) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint91_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(380 436) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint92_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(350 436) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint93_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(440 376) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint94_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(380 376) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint95_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(410 376) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint96_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(350 406) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint97_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(470 406) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint98_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(350 376) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint99_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(470 376) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint100_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(350 346) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint101_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(470 346) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint102_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(470 316) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint103_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(440 316) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint104_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(410 316) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint105_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(380 316) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint106_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(350 316) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint107_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(290 376) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint108_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(230 436) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint109_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(200 436) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint110_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(170 436) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint111_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(140 436) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint112_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(110 436) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint113_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(200 376) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint114_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(140 376) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint115_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(170 376) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint116_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(110 406) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint117_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(230 406) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint118_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(110 376) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint119_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(230 376) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint120_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(110 346) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint121_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(230 346) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint122_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(230 316) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint123_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(200 316) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint124_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(170 316) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint125_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(140 316) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint126_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(110 316) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint127_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(680 466) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint128_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(440 466) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint129_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(200 466) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint130_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(620 466) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint131_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(380 466) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint132_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(140 466) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint133_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(590 196) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint134_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(620 196) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint135_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(650 196) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint136_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(680 196) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
<radialGradient id="paint137_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(710 196) rotate(90) scale(7.5)">
<stop stop-color="#B39351"/>
<stop offset="1" stop-color="#FFECC4"/>
</radialGradient>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 71 KiB

4
examples/.keep Normal file
View file

@ -0,0 +1,4 @@
File generated from our OpenAPI spec by Stainless.
This directory can be used to store example files demonstrating usage of this SDK.
It is ignored by Stainless code generation and its content (other than this keep file) won't be touched.

50
mypy.ini Normal file
View file

@ -0,0 +1,50 @@
[mypy]
pretty = True
show_error_codes = True
# Exclude _files.py because mypy isn't smart enough to apply
# the correct type narrowing and as this is an internal module
# it's fine to just use Pyright.
#
# We also exclude our `tests` as mypy doesn't always infer
# types correctly and Pyright will still catch any type errors.
exclude = ^(src/tinker/_files\.py|_dev/.*\.py|tests/.*)$
strict_equality = True
implicit_reexport = True
check_untyped_defs = True
no_implicit_optional = True
warn_return_any = True
warn_unreachable = True
warn_unused_configs = True
# Turn these options off as it could cause conflicts
# with the Pyright options.
warn_unused_ignores = False
warn_redundant_casts = False
disallow_any_generics = True
disallow_untyped_defs = True
disallow_untyped_calls = True
disallow_subclassing_any = True
disallow_incomplete_defs = True
disallow_untyped_decorators = True
cache_fine_grained = True
# By default, mypy reports an error if you assign a value to the result
# of a function call that doesn't return anything. We do this in our test
# cases:
# ```
# result = ...
# assert result is None
# ```
# Changing this codegen to make mypy happy would increase complexity
# and would not be worth it.
disable_error_code = func-returns-value,overload-cannot-match
# https://github.com/python/mypy/issues/12162
# Per-module override. In the INI config format this must be a
# `[mypy-<module pattern>]` section header; the `[mypy.overrides]` table with
# a `module` key is TOML (pyproject) syntax and is ignored in mypy.ini.
[mypy-black.files.*]
ignore_errors = true
ignore_missing_imports = true

185
pyproject.toml Normal file
View file

@ -0,0 +1,185 @@
[project]
name = "tinker"
version = "0.1.3"
description = "The official Python SDK for the tinker API"
readme = "README.md"
license = "Apache-2.0"
authors = [
{ name = "Tinker authors", email = "tinker@thinkingmachines.ai" },
]
keywords = ["tinker", "machine learning"]
dependencies = [
"httpx[http2]>=0.23.0, <1",
"pydantic>=1.9.0, <3",
"typing-extensions>=4.10, <5",
"anyio>=3.5.0, <5",
"distro>=1.7.0, <2",
"sniffio",
"numpy",
"torch",
]
requires-python = ">= 3.11"
classifiers = [
"Typing :: Typed",
"Intended Audience :: Developers",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Operating System :: OS Independent",
"Operating System :: POSIX",
"Operating System :: MacOS",
"Operating System :: POSIX :: Linux",
"Operating System :: Microsoft :: Windows",
"Topic :: Software Development :: Libraries :: Python Modules",
"License :: OSI Approved :: Apache Software License"
]
[project.urls]
Homepage = "https://thinkingmachines.ai/tinker"
Repository = "https://github.com/thinking-machines-lab/tinker"
Documentation = "https://tinker-docs.thinkingmachines.ai/"
[project.optional-dependencies]
aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.8"]
[tool.uv]
managed = true
required-version = ">=0.5.0"
# version pins are in uv.lock
dev-dependencies = [
"pyright==1.1.402",
"mypy",
"respx",
"pytest",
"pytest-asyncio",
"ruff",
"time-machine",
"dirty-equals>=0.6.0",
"importlib-metadata>=6.7.0",
"rich>=13.7.1",
"nest_asyncio==1.6.0",
"pytest-xdist>=3.6.1",
]
[build-system]
requires = ["hatchling==1.26.3", "hatch-fancy-pypi-readme"]
build-backend = "hatchling.build"
[tool.hatch.build]
include = [
"src/*"
]
[tool.hatch.build.targets.wheel]
packages = ["src/tinker"]
[tool.hatch.build.targets.sdist]
# Basically everything except hidden files/directories (such as .github, .devcontainers, .python-version, etc)
include = [
"/*.toml",
"/*.json",
"/*.lock",
"/*.md",
"/mypy.ini",
"/noxfile.py",
"bin/*",
"examples/*",
"src/*",
"tests/*",
]
[tool.hatch.metadata.hooks.fancy-pypi-readme]
content-type = "text/markdown"
[[tool.hatch.metadata.hooks.fancy-pypi-readme.fragments]]
path = "README.md"
[[tool.hatch.metadata.hooks.fancy-pypi-readme.substitutions]]
# replace relative links with absolute links
pattern = '\[(.+?)\]\(((?!https?://)\S+?)\)'
replacement = '[\1](https://github.com/stainless-sdks/tinker-python/tree/main/\g<2>)'
[tool.pytest.ini_options]
testpaths = ["tests"]
addopts = "--tb=short -n auto"
xfail_strict = true
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "session"
filterwarnings = [
"error"
]
[tool.pyright]
# this enables practically every flag given by pyright.
# there are a couple of flags that are still disabled by
# default in strict mode as they are experimental and niche.
typeCheckingMode = "strict"
pythonVersion = "3.8"
exclude = [
"_dev",
".venv",
".nox",
]
reportImplicitOverride = true
reportOverlappingOverload = false
reportImportCycles = false
reportPrivateUsage = false
[tool.ruff]
line-length = 120
output-format = "grouped"
target-version = "py38"
[tool.ruff.format]
docstring-code-format = true
[tool.ruff.lint]
select = [
# isort
"I",
# bugbear rules
"B",
# remove unused imports
"F401",
# bare except statements
"E722",
# unused arguments
"ARG",
# print statements
"T201",
"T203",
# misuse of typing.TYPE_CHECKING
"TC004",
# import rules
"TID251",
]
ignore = [
# mutable defaults
"B006",
]
unfixable = [
# disable auto fix for print statements
"T201",
"T203",
]
[tool.ruff.lint.flake8-tidy-imports.banned-api]
"functools.lru_cache".msg = "This function does not retain type information for the wrapped function's arguments; The `lru_cache` function from `_utils` should be used instead"
[tool.ruff.lint.isort]
length-sort = true
length-sort-straight = true
combine-as-imports = true
extra-standard-library = ["typing_extensions"]
known-first-party = ["tinker", "tests"]
[tool.ruff.lint.per-file-ignores]
"bin/**.py" = ["T201", "T203"]
"scripts/**.py" = ["T201", "T203"]
"tests/**.py" = ["T201", "T203"]
"examples/**.py" = ["T201", "T203"]

105
requirements-dev.lock Normal file
View file

@ -0,0 +1,105 @@
# This file was autogenerated by uv via the following command:
# uv export -o requirements-dev.lock --no-hashes
-e .
annotated-types==0.7.0
# via pydantic
anyio==4.5.2 ; python_full_version < '3.9'
# via
# httpx
# tinker
anyio==4.8.0 ; python_full_version >= '3.9'
# via
# httpx
# tinker
certifi==2024.12.14
# via
# httpcore
# httpx
colorama==0.4.6 ; sys_platform == 'win32'
# via pytest
dirty-equals==0.9.0
distro==1.9.0
# via tinker
exceptiongroup==1.2.2 ; python_full_version < '3.11'
# via
# anyio
# pytest
execnet==2.1.1
# via pytest-xdist
h11==0.16.0
# via httpcore
httpcore==1.0.9
# via httpx
httpx==0.28.1
# via
# respx
# tinker
idna==3.10
# via
# anyio
# httpx
importlib-metadata==8.5.0 ; python_full_version < '3.9'
importlib-metadata==8.6.1 ; python_full_version >= '3.9'
iniconfig==2.0.0
# via pytest
markdown-it-py==3.0.0
# via rich
mdurl==0.1.2
# via markdown-it-py
mypy==1.14.1
mypy-extensions==1.0.0
# via mypy
nest-asyncio==1.6.0
nodeenv==1.9.1
# via pyright
packaging==24.2
# via pytest
pluggy==1.5.0
# via pytest
pydantic==2.10.3
# via tinker
pydantic-core==2.27.1
# via pydantic
pygments==2.19.1
# via rich
pyright==1.1.399
pytest==8.3.3
# via
# pytest-asyncio
# pytest-xdist
pytest-asyncio==0.24.0
pytest-xdist==3.6.1 ; python_full_version < '3.9'
pytest-xdist==3.7.0 ; python_full_version >= '3.9'
python-dateutil==2.9.0.post0
# via time-machine
pytz==2024.2 ; python_full_version < '3.9'
# via dirty-equals
respx==0.22.0
rich==13.9.4
ruff==0.9.4
six==1.17.0
# via python-dateutil
sniffio==1.3.1
# via
# anyio
# tinker
time-machine==2.15.0 ; python_full_version < '3.9'
time-machine==2.16.0 ; python_full_version >= '3.9'
tomli==2.2.1 ; python_full_version < '3.11'
# via
# mypy
# pytest
typing-extensions==4.12.2
# via
# annotated-types
# anyio
# mypy
# pydantic
# pydantic-core
# pyright
# rich
# tinker
zipp==3.20.2 ; python_full_version < '3.9'
# via importlib-metadata
zipp==3.21.0 ; python_full_version >= '3.9'
# via importlib-metadata

22
scripts/bootstrap Executable file
View file

@ -0,0 +1,22 @@
#!/usr/bin/env bash

set -e

# Always operate from the repository root, regardless of invocation directory.
cd "$(dirname "$0")/.."

# On macOS, install system-level dependencies from the Brewfile unless the
# caller opted out with SKIP_BREW=1. `brew bundle check` is a cheap no-op
# when everything is already installed.
if [ -f "Brewfile" ] && [ "$(uname -s)" = "Darwin" ] && [ "$SKIP_BREW" != "1" ]; then
  brew bundle check >/dev/null 2>&1 || {
    echo "==> Installing Homebrew dependencies…"
    brew bundle
  }
fi

echo "==> Installing Python…"
uv python install

echo "==> Installing Python dependencies…"
uv sync --all-extras

echo "==> Exporting Python dependencies…"

# note: `--no-hashes` is required because of https://github.com/pypa/pip/issues/4995
uv export -o requirements-dev.lock --no-hashes

14
scripts/format Executable file
View file

@ -0,0 +1,14 @@
#!/usr/bin/env bash

set -e

# Run from the repository root.
cd "$(dirname "$0")/.."

echo "==> Running ruff"
uv run ruff format
uv run ruff check --fix .

# run formatting again to fix any inconsistencies when imports are stripped
uv run ruff format

echo "==> Formatting docs"
# Reformat the Python code blocks embedded in the Markdown docs.
uv run python scripts/utils/ruffen-docs.py README.md api.md

17
scripts/lint Executable file
View file

@ -0,0 +1,17 @@
#!/usr/bin/env bash

set -e

# Run from the repository root.
cd "$(dirname "$0")/.."

echo "==> Running ruff"
uv run ruff check .

echo "==> Running pyright"
uv run pyright

echo "==> Running mypy"
uv run mypy .

echo "==> Making sure it imports"
# Smoke test: the package must at least be importable.
uv run python -c 'import tinker'

41
scripts/mock Executable file
View file

@ -0,0 +1,41 @@
#!/usr/bin/env bash

set -e

# Run from the repository root.
cd "$(dirname "$0")/.."

# Use the spec passed as the first positional argument, unless it looks like a flag;
# otherwise fall back to the spec URL recorded in .stats.yml.
if [[ -n "$1" && "$1" != '--'* ]]; then
  URL="$1"
  shift
else
  URL="$(grep 'openapi_spec_url' .stats.yml | cut -d' ' -f2)"
fi

# Check if the URL is empty
if [ -z "$URL" ]; then
  echo "Error: No OpenAPI spec path/url provided or found in .stats.yml"
  exit 1
fi

echo "==> Starting mock server with URL ${URL}"

# Run prism mock on the given spec
if [ "$1" == "--daemon" ]; then
  # Background the server and capture its output so we can poll it below.
  npm exec --package=@stainless-api/prism-cli@5.15.0 -- prism mock "$URL" &>.prism.log &

  # Wait for server to come online
  echo -n "Waiting for server"
  while ! grep -q "✖ fatal\|Prism is listening" ".prism.log"; do
    echo -n "."
    sleep 0.1
  done

  # If prism logged a fatal error instead of the listening banner, surface it and fail.
  if grep -q "✖ fatal" ".prism.log"; then
    cat .prism.log
    exit 1
  fi

  echo
else
  npm exec --package=@stainless-api/prism-cli@5.15.0 -- prism mock "$URL"
fi

77
scripts/test Executable file
View file

@ -0,0 +1,77 @@
#!/usr/bin/env bash

set -e

# Run from the repository root.
cd "$(dirname "$0")/.."

RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[0;33m'
NC='\033[0m' # No Color

# True when a prism mock server is already answering on the default port.
function prism_is_running() {
  curl --silent "http://localhost:4010" >/dev/null 2>&1
}

kill_server_on_port() {
  pids=$(lsof -t -i tcp:"$1" || echo "")
  if [ "$pids" != "" ]; then
    # `lsof -t` prints one PID per line; rely on word splitting so each PID
    # becomes its own argument. Quoting "$pids" would hand `kill` a single
    # newline-joined argument and fail when more than one process holds the port.
    # shellcheck disable=SC2086
    kill $pids
    echo "Stopped $pids."
  fi
}

function is_overriding_api_base_url() {
  [ -n "$TEST_API_BASE_URL" ]
}

if ! is_overriding_api_base_url && ! prism_is_running; then
  # When we exit this script, make sure to kill the background mock server process
  trap 'kill_server_on_port 4010' EXIT

  # Start the dev server
  ./scripts/mock --daemon
fi

if is_overriding_api_base_url; then
  echo -e "${GREEN}✔ Running tests against ${TEST_API_BASE_URL}${NC}"
  echo
elif ! prism_is_running; then
  echo -e "${RED}ERROR:${NC} The test suite will not run without a mock Prism server"
  echo -e "running against your OpenAPI spec."
  echo
  echo -e "To run the server, pass in the path or url of your OpenAPI"
  echo -e "spec to the prism command:"
  echo
  echo -e " \$ ${YELLOW}npm exec --package=@stainless-api/prism-cli@5.15.0 -- prism mock path/to/your.openapi.yml${NC}"
  echo
  exit 1
else
  echo -e "${GREEN}✔ Mock prism server is running with your OpenAPI spec${NC}"
  echo
fi

export DEFER_PYDANTIC_BUILD=false

# Run the suite under both major Pydantic versions.
function run_tests() {
  echo "==> Running tests with Pydantic v2"
  uv run --all-extras --all-groups pytest "$@"

  echo "==> Running tests with Pydantic v1"
  uv pip install 'pydantic<2'
  uv run --all-extras --all-groups pytest "$@"
}

# If UV_PYTHON is already set in the environment, just run the command once
if [[ -n "$UV_PYTHON" ]]; then
  run_tests "$@"
else
  # If UV_PYTHON is not set, run the command for min and max versions
  echo "==> Running tests for Python 3.9"
  UV_PYTHON=3.9 run_tests "$@"

  echo "==> Running tests for Python 3.13"
  UV_PYTHON=3.13 run_tests "$@"
fi

View file

@ -0,0 +1,167 @@
# fork of https://github.com/asottile/blacken-docs adapted for ruff
from __future__ import annotations
import re
import sys
import argparse
import textwrap
import contextlib
import subprocess
from typing import Match, Optional, Sequence, Generator, NamedTuple, cast
# Matches fenced ```python blocks in Markdown; `indent` captures the fence's
# column so nested blocks can be re-indented after formatting.
MD_RE = re.compile(
    r"(?P<before>^(?P<indent> *)```\s*python\n)" r"(?P<code>.*?)" r"(?P<after>^(?P=indent)```\s*$)",
    re.DOTALL | re.MULTILINE,
)
# Matches fenced ```pycon (interactive REPL transcript) blocks.
MD_PYCON_RE = re.compile(
    r"(?P<before>^(?P<indent> *)```\s*pycon\n)" r"(?P<code>.*?)" r"(?P<after>^(?P=indent)```.*$)",
    re.DOTALL | re.MULTILINE,
)
# REPL prompt prefixes used inside ```pycon transcripts.
PYCON_PREFIX = ">>> "
PYCON_CONTINUATION_PREFIX = "..."
PYCON_CONTINUATION_RE = re.compile(
    rf"^{re.escape(PYCON_CONTINUATION_PREFIX)}( |$)",
)
# Line length passed to `ruff format` for embedded snippets.
DEFAULT_LINE_LENGTH = 100
class CodeBlockError(NamedTuple):
    """A formatting failure for one embedded code block."""

    # Character offset of the failing block within the source document.
    offset: int
    # The exception raised while formatting that block.
    exc: Exception
def format_str(
    src: str,
) -> tuple[str, Sequence[CodeBlockError]]:
    """Reformat every fenced ``python``/``pycon`` code block in *src*.

    Returns the rewritten document plus the list of per-block formatting
    errors; errors are collected rather than raised so one bad block does
    not stop the rest of the document from being processed.
    """
    errors: list[CodeBlockError] = []

    @contextlib.contextmanager
    def _collect_error(match: Match[str]) -> Generator[None, None, None]:
        # Record any formatter failure together with the block's position.
        try:
            yield
        except Exception as e:
            errors.append(CodeBlockError(match.start(), e))

    def _md_match(match: Match[str]) -> str:
        # Dedent, format, then re-indent so nested fences keep their column.
        code = textwrap.dedent(match["code"])
        with _collect_error(match):
            code = format_code_block(code)
        code = textwrap.indent(code, match["indent"])
        return f"{match['before']}{code}{match['after']}"

    def _pycon_match(match: Match[str]) -> str:
        code = ""
        fragment = cast(Optional[str], None)

        def finish_fragment() -> None:
            # Flush the accumulated REPL statement back out with >>>/... prompts.
            nonlocal code
            nonlocal fragment

            if fragment is not None:
                with _collect_error(match):
                    fragment = format_code_block(fragment)
                fragment_lines = fragment.splitlines()
                code += f"{PYCON_PREFIX}{fragment_lines[0]}\n"
                for line in fragment_lines[1:]:
                    # Skip blank lines to handle Black adding a blank above
                    # functions within blocks. A blank line would end the REPL
                    # continuation prompt.
                    #
                    # >>> if True:
                    # ...     def f():
                    # ...         pass
                    # ...
                    if line:
                        code += f"{PYCON_CONTINUATION_PREFIX} {line}\n"
                # A trailing indented line means the block is still open, so
                # emit a bare continuation prompt to close it.
                if fragment_lines[-1].startswith(" "):
                    code += f"{PYCON_CONTINUATION_PREFIX}\n"
                fragment = None

        indentation = None
        for line in match["code"].splitlines():
            orig_line, line = line, line.lstrip()
            if indentation is None and line:
                # The first non-blank line fixes the transcript's indentation.
                indentation = len(orig_line) - len(line)
            continuation_match = PYCON_CONTINUATION_RE.match(line)
            if continuation_match and fragment is not None:
                fragment += line[continuation_match.end() :] + "\n"
            else:
                finish_fragment()
                if line.startswith(PYCON_PREFIX):
                    fragment = line[len(PYCON_PREFIX) :] + "\n"
                else:
                    # Non-prompt (output) lines are copied through unchanged.
                    code += orig_line[indentation:] + "\n"
        finish_fragment()

        return code

    def _md_pycon_match(match: Match[str]) -> str:
        code = _pycon_match(match)
        code = textwrap.indent(code, match["indent"])
        return f"{match['before']}{code}{match['after']}"

    src = MD_RE.sub(_md_match, src)
    src = MD_PYCON_RE.sub(_md_pycon_match, src)

    return src, errors
def format_code_block(code: str) -> str:
    """Pipe *code* through `ruff format` and return the reformatted text."""
    # `--stdin-filename` gives ruff a pseudo-path so it treats input as a script.
    cmd = [
        sys.executable,
        "-m",
        "ruff",
        "format",
        "--stdin-filename=script.py",
        f"--line-length={DEFAULT_LINE_LENGTH}",
    ]
    return subprocess.check_output(cmd, encoding="utf-8", input=code)
def format_file(
    filename: str,
    skip_errors: bool,
) -> int:
    """Reformat the code blocks of *filename* in place.

    Prints a diagnostic per unparsable block and returns 1 when any block
    failed and `skip_errors` is false; otherwise returns 0 (rewriting the
    file only when its contents actually changed).
    """
    with open(filename, encoding="UTF-8") as f:
        contents = f.read()
    new_contents, errors = format_str(contents)
    for error in errors:
        # Translate the block's character offset into a 1-based line number.
        lineno = contents[: error.offset].count("\n") + 1
        # Include the filename so multi-file runs produce actionable messages
        # (the previous message hard-coded "(unknown)" and, for the rewrite
        # notice below, was an f-string with no placeholders at all).
        print(f"{filename}:{lineno}: code block parse error {error.exc}")
    if errors and not skip_errors:
        return 1
    if contents != new_contents:
        print(f"{filename}: Rewriting...")
        with open(filename, "w", encoding="UTF-8") as f:
            f.write(new_contents)
    return 0
def main(argv: Sequence[str] | None = None) -> int:
    """Parse CLI arguments and reformat each named file; return an exit code."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-l",
        "--line-length",
        type=int,
        default=DEFAULT_LINE_LENGTH,
    )
    parser.add_argument(
        "-S",
        "--skip-string-normalization",
        action="store_true",
    )
    parser.add_argument("-E", "--skip-errors", action="store_true")
    parser.add_argument("filenames", nargs="*")
    args = parser.parse_args(argv)

    # OR the per-file results together: nonzero if any file failed.
    exit_code = 0
    for name in args.filenames:
        exit_code |= format_file(name, skip_errors=args.skip_errors)
    return exit_code


if __name__ == "__main__":
    raise SystemExit(main())

View file

@ -0,0 +1,27 @@
#!/usr/bin/env bash

set -exuo pipefail

# NOTE(review): this script expects $URL and $AUTH (and later $SHA) in the
# environment — presumably injected by CI; verify against the workflow.
FILENAME=$(basename dist/*.whl)

# Ask the signing endpoint for a one-time upload URL for this wheel.
RESPONSE=$(curl -X POST "$URL?filename=$FILENAME" \
  -H "Authorization: Bearer $AUTH" \
  -H "Content-Type: application/json")

SIGNED_URL=$(echo "$RESPONSE" | jq -r '.url')

if [[ "$SIGNED_URL" == "null" ]]; then
  echo -e "\033[31mFailed to get signed URL.\033[0m"
  exit 1
fi

# Upload the wheel; `-v` output (captured via 2>&1) includes the HTTP status
# line, which is inspected below.
UPLOAD_RESPONSE=$(curl -v -X PUT \
  -H "Content-Type: binary/octet-stream" \
  --data-binary "@dist/$FILENAME" "$SIGNED_URL" 2>&1)

if echo "$UPLOAD_RESPONSE" | grep -q "HTTP/[0-9.]* 200"; then
  echo -e "\033[32mUploaded build to Stainless storage.\033[0m"
  echo -e "\033[32mInstallation: pip install 'https://pkg.stainless.com/s/tinker-python/$SHA/$FILENAME'\033[0m"
else
  echo -e "\033[31mFailed to upload artifact.\033[0m"
  exit 1
fi

132
src/tinker/__init__.py Normal file
View file

@ -0,0 +1,132 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
import typing as _t
from . import types
from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes
from ._utils import file_from_path
from ._client import Timeout, Transport, RequestOptions
from ._models import BaseModel
from ._version import __title__, __version__
from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse
from ._constants import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES, DEFAULT_CONNECTION_LIMITS
from ._exceptions import (
APIError,
TinkerError,
ConflictError,
NotFoundError,
APIStatusError,
RateLimitError,
APITimeoutError,
BadRequestError,
APIConnectionError,
AuthenticationError,
InternalServerError,
PermissionDeniedError,
UnprocessableEntityError,
APIResponseValidationError,
)
from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient
from ._utils._logs import setup_logging as _setup_logging
from .lib.public_interfaces import TrainingClient, ServiceClient, SamplingClient, APIFuture
# Import commonly used types for easier access
from .types import (
AdamParams,
Checkpoint,
CheckpointType,
Datum,
EncodedTextChunk,
ForwardBackwardOutput,
LoraConfig,
ModelID,
ModelInput,
ModelInputChunk,
OptimStepRequest,
OptimStepResponse,
ParsedCheckpointTinkerPath,
SampledSequence,
SampleRequest,
SampleResponse,
SamplingParams,
StopReason,
TensorData,
TensorDtype,
TrainingRun,
)
__all__ = [
# Core clients
"TrainingClient",
"ServiceClient",
"SamplingClient",
"APIFuture",
# Commonly used types
"AdamParams",
"Checkpoint",
"CheckpointType",
"Datum",
"EncodedTextChunk",
"ForwardBackwardOutput",
"LoraConfig",
"ModelID",
"ModelInput",
"ModelInputChunk",
"OptimStepRequest",
"OptimStepResponse",
"ParsedCheckpointTinkerPath",
"SampledSequence",
"SampleRequest",
"SampleResponse",
"SamplingParams",
"StopReason",
"TensorData",
"TensorDtype",
"TrainingRun",
# Client configuration
"Timeout",
"RequestOptions",
# Exception types
"TinkerError",
"APIError",
"APIStatusError",
"APITimeoutError",
"APIConnectionError",
"APIResponseValidationError",
"BadRequestError",
"AuthenticationError",
"PermissionDeniedError",
"NotFoundError",
"ConflictError",
"UnprocessableEntityError",
"RateLimitError",
"InternalServerError",
# Keep types module for advanced use
"types",
# Version info
"__version__",
"__title__",
]
# At runtime (but not for type checkers) expose a lazy `resources` proxy
# module under the package namespace.
if not _t.TYPE_CHECKING:
    from ._utils._resources_proxy import resources as resources

# Install the SDK's default logging configuration.
_setup_logging()

# Update the __module__ attribute for exported symbols so that
# error messages point to this module instead of the module
# it was originally defined in, e.g.
# tinker._exceptions.NotFoundError -> tinker.NotFoundError
__locals = locals()
for __name in __all__:
    if not __name.startswith("__"):
        try:
            __locals[__name].__module__ = "tinker"
        except (TypeError, AttributeError):
            # Some of our exported symbols are builtins which we can't set attributes for.
            pass

1997
src/tinker/_base_client.py Normal file

File diff suppressed because it is too large Load diff

668
src/tinker/_client.py Normal file
View file

@ -0,0 +1,668 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
import os
from typing import TYPE_CHECKING, Any, Union, Mapping
from typing_extensions import Self, override
import httpx
from . import _exceptions
from ._qs import Querystring
from ._types import (
NOT_GIVEN,
Omit,
Timeout,
NotGiven,
Transport,
ProxiesTypes,
RequestOptions,
)
from ._utils import is_given, get_async_library
from ._compat import cached_property
from ._version import __version__
from ._streaming import Stream as Stream, AsyncStream as AsyncStream
from ._exceptions import TinkerError, APIStatusError
from ._base_client import (
DEFAULT_MAX_RETRIES,
SyncAPIClient,
AsyncAPIClient,
)
if TYPE_CHECKING:
from .resources import models, futures, service, weights, sampling, training, telemetry
from .resources.models import ModelsResource, AsyncModelsResource
from .resources.futures import FuturesResource, AsyncFuturesResource
from .resources.service import ServiceResource, AsyncServiceResource
from .resources.weights import WeightsResource, AsyncWeightsResource
from .resources.sampling import SamplingResource, AsyncSamplingResource
from .resources.training import TrainingResource, AsyncTrainingResource
from .resources.telemetry import TelemetryResource, AsyncTelemetryResource
__all__ = ["Timeout", "Transport", "ProxiesTypes", "RequestOptions", "Tinker", "AsyncTinker", "Client", "AsyncClient"]
class Tinker(SyncAPIClient):
    # client options
    # API key sent on every request via the `X-API-Key` header.
    api_key: str
def __init__(
self,
*,
api_key: str | None = None,
base_url: str | httpx.URL | None = None,
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
# Configure a custom httpx client.
# We provide a `DefaultHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`.
# See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
http_client: httpx.Client | None = None,
# Enable or disable schema validation for data returned by the API.
# When enabled an error APIResponseValidationError is raised
# if the API responds with invalid data for the expected schema.
#
# This parameter may be removed or changed in the future.
# If you rely on this feature, please open a GitHub issue
# outlining your use-case to help us decide if it should be
# part of our public interface in the future.
_strict_response_validation: bool = False,
) -> None:
"""Construct a new synchronous Tinker client instance.
This automatically infers the `api_key` argument from the `TINKER_API_KEY` environment variable if it is not provided.
"""
if api_key is None:
api_key = os.environ.get("TINKER_API_KEY")
if api_key is None:
raise TinkerError(
"The api_key client option must be set either by passing api_key to the client or by setting the TINKER_API_KEY environment variable"
)
self.api_key = api_key
if base_url is None:
base_url = os.environ.get("TINKER_BASE_URL")
if base_url is None:
base_url = f"https://tinker.thinkingmachines.dev/services/tinker-prod"
super().__init__(
version=__version__,
base_url=base_url,
max_retries=max_retries,
timeout=timeout,
http_client=http_client,
custom_headers=default_headers,
custom_query=default_query,
_strict_response_validation=_strict_response_validation,
)
self._idempotency_header = "X-Idempotency-Key"
@cached_property
def service(self) -> ServiceResource:
from .resources.service import ServiceResource
return ServiceResource(self)
@cached_property
def training(self) -> TrainingResource:
from .resources.training import TrainingResource
return TrainingResource(self)
@cached_property
def models(self) -> ModelsResource:
from .resources.models import ModelsResource
return ModelsResource(self)
@cached_property
def weights(self) -> WeightsResource:
from .resources.weights import WeightsResource
return WeightsResource(self)
@cached_property
def sampling(self) -> SamplingResource:
from .resources.sampling import SamplingResource
return SamplingResource(self)
@cached_property
def futures(self) -> FuturesResource:
from .resources.futures import FuturesResource
return FuturesResource(self)
@cached_property
def telemetry(self) -> TelemetryResource:
from .resources.telemetry import TelemetryResource
return TelemetryResource(self)
@cached_property
def with_raw_response(self) -> TinkerWithRawResponse:
return TinkerWithRawResponse(self)
@cached_property
def with_streaming_response(self) -> TinkerWithStreamedResponse:
return TinkerWithStreamedResponse(self)
@property
@override
def qs(self) -> Querystring:
return Querystring(array_format="comma")
@property
@override
def auth_headers(self) -> dict[str, str]:
api_key = self.api_key
return {"X-API-Key": api_key}
@property
@override
def default_headers(self) -> dict[str, str | Omit]:
return {
**super().default_headers,
"X-Stainless-Async": "false",
**self._custom_headers,
}
def copy(
self,
*,
api_key: str | None = None,
base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
http_client: httpx.Client | None = None,
max_retries: int | NotGiven = NOT_GIVEN,
default_headers: Mapping[str, str] | None = None,
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
set_default_query: Mapping[str, object] | None = None,
_extra_kwargs: Mapping[str, Any] = {},
) -> Self:
"""
Create a new client instance re-using the same options given to the current client with optional overriding.
"""
if default_headers is not None and set_default_headers is not None:
raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")
if default_query is not None and set_default_query is not None:
raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive")
headers = self._custom_headers
if default_headers is not None:
headers = {**headers, **default_headers}
elif set_default_headers is not None:
headers = set_default_headers
params = self._custom_query
if default_query is not None:
params = {**params, **default_query}
elif set_default_query is not None:
params = set_default_query
http_client = http_client or self._client
return self.__class__(
api_key=api_key or self.api_key,
base_url=base_url or self.base_url,
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
http_client=http_client,
max_retries=max_retries if is_given(max_retries) else self.max_retries,
default_headers=headers,
default_query=params,
**_extra_kwargs,
)
# Alias for `copy` for nicer inline usage, e.g.
# client.with_options(timeout=10).foo.create(...)
with_options = copy
@override
def _make_status_error(
self,
err_msg: str,
*,
body: object,
response: httpx.Response,
) -> APIStatusError:
if response.status_code == 400:
return _exceptions.BadRequestError(err_msg, response=response, body=body)
if response.status_code == 401:
return _exceptions.AuthenticationError(err_msg, response=response, body=body)
if response.status_code == 403:
return _exceptions.PermissionDeniedError(err_msg, response=response, body=body)
if response.status_code == 404:
return _exceptions.NotFoundError(err_msg, response=response, body=body)
if response.status_code == 409:
return _exceptions.ConflictError(err_msg, response=response, body=body)
if response.status_code == 422:
return _exceptions.UnprocessableEntityError(err_msg, response=response, body=body)
if response.status_code == 429:
return _exceptions.RateLimitError(err_msg, response=response, body=body)
if response.status_code >= 500:
return _exceptions.InternalServerError(err_msg, response=response, body=body)
return APIStatusError(err_msg, response=response, body=body)
class AsyncTinker(AsyncAPIClient):
    """Asynchronous client for the Tinker API.

    Resources are exposed as lazily imported cached properties so that importing
    this module stays cheap and import cycles with the resource modules are avoided.
    """

    # client options
    api_key: str

    def __init__(
        self,
        *,
        api_key: str | None = None,
        base_url: str | httpx.URL | None = None,
        timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
        max_retries: int = DEFAULT_MAX_RETRIES,
        default_headers: Mapping[str, str] | None = None,
        default_query: Mapping[str, object] | None = None,
        # Configure a custom httpx client.
        # We provide a `DefaultAsyncHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`.
        # See the [httpx documentation](https://www.python-httpx.org/api/#asyncclient) for more details.
        http_client: httpx.AsyncClient | None = None,
        # Enable or disable schema validation for data returned by the API.
        # When enabled an error APIResponseValidationError is raised
        # if the API responds with invalid data for the expected schema.
        #
        # This parameter may be removed or changed in the future.
        # If you rely on this feature, please open a GitHub issue
        # outlining your use-case to help us decide if it should be
        # part of our public interface in the future.
        _strict_response_validation: bool = False,
    ) -> None:
        """Construct a new async AsyncTinker client instance.

        This automatically infers the `api_key` argument from the `TINKER_API_KEY` environment variable if it is not provided.

        Raises:
            TinkerError: if no API key was passed and `TINKER_API_KEY` is unset.
        """
        if api_key is None:
            api_key = os.environ.get("TINKER_API_KEY")
        if api_key is None:
            raise TinkerError(
                "The api_key client option must be set either by passing api_key to the client or by setting the TINKER_API_KEY environment variable"
            )
        self.api_key = api_key

        if base_url is None:
            base_url = os.environ.get("TINKER_BASE_URL")
        if base_url is None:
            # Plain string literal: no placeholders, so the stray `f` prefix was removed (ruff F541).
            base_url = "https://tinker.thinkingmachines.dev/services/tinker-prod"

        super().__init__(
            version=__version__,
            base_url=base_url,
            max_retries=max_retries,
            timeout=timeout,
            http_client=http_client,
            custom_headers=default_headers,
            custom_query=default_query,
            _strict_response_validation=_strict_response_validation,
        )

        self._idempotency_header = "X-Idempotency-Key"

    @cached_property
    def service(self) -> AsyncServiceResource:
        from .resources.service import AsyncServiceResource

        return AsyncServiceResource(self)

    @cached_property
    def training(self) -> AsyncTrainingResource:
        from .resources.training import AsyncTrainingResource

        return AsyncTrainingResource(self)

    @cached_property
    def models(self) -> AsyncModelsResource:
        from .resources.models import AsyncModelsResource

        return AsyncModelsResource(self)

    @cached_property
    def weights(self) -> AsyncWeightsResource:
        from .resources.weights import AsyncWeightsResource

        return AsyncWeightsResource(self)

    @cached_property
    def sampling(self) -> AsyncSamplingResource:
        from .resources.sampling import AsyncSamplingResource

        return AsyncSamplingResource(self)

    @cached_property
    def futures(self) -> AsyncFuturesResource:
        from .resources.futures import AsyncFuturesResource

        return AsyncFuturesResource(self)

    @cached_property
    def telemetry(self) -> AsyncTelemetryResource:
        from .resources.telemetry import AsyncTelemetryResource

        return AsyncTelemetryResource(self)

    @cached_property
    def with_raw_response(self) -> AsyncTinkerWithRawResponse:
        return AsyncTinkerWithRawResponse(self)

    @cached_property
    def with_streaming_response(self) -> AsyncTinkerWithStreamedResponse:
        return AsyncTinkerWithStreamedResponse(self)

    @property
    @override
    def qs(self) -> Querystring:
        # Serialize list-valued query params as comma-separated values.
        return Querystring(array_format="comma")

    @property
    @override
    def auth_headers(self) -> dict[str, str]:
        api_key = self.api_key
        return {"X-API-Key": api_key}

    @property
    @override
    def default_headers(self) -> dict[str, str | Omit]:
        return {
            **super().default_headers,
            "X-Stainless-Async": f"async:{get_async_library()}",
            **self._custom_headers,
        }

    def copy(
        self,
        *,
        api_key: str | None = None,
        base_url: str | httpx.URL | None = None,
        timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
        http_client: httpx.AsyncClient | None = None,
        max_retries: int | NotGiven = NOT_GIVEN,
        default_headers: Mapping[str, str] | None = None,
        set_default_headers: Mapping[str, str] | None = None,
        default_query: Mapping[str, object] | None = None,
        set_default_query: Mapping[str, object] | None = None,
        _extra_kwargs: Mapping[str, Any] = {},
    ) -> Self:
        """
        Create a new client instance re-using the same options given to the current client with optional overriding.

        `default_headers`/`default_query` are merged into the current values, while
        `set_default_headers`/`set_default_query` replace them wholesale; each pair
        is mutually exclusive.
        """
        if default_headers is not None and set_default_headers is not None:
            raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")

        if default_query is not None and set_default_query is not None:
            raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive")

        headers = self._custom_headers
        if default_headers is not None:
            headers = {**headers, **default_headers}
        elif set_default_headers is not None:
            headers = set_default_headers

        params = self._custom_query
        if default_query is not None:
            params = {**params, **default_query}
        elif set_default_query is not None:
            params = set_default_query

        http_client = http_client or self._client
        return self.__class__(
            api_key=api_key or self.api_key,
            base_url=base_url or self.base_url,
            timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
            http_client=http_client,
            max_retries=max_retries if is_given(max_retries) else self.max_retries,
            default_headers=headers,
            default_query=params,
            **_extra_kwargs,
        )

    # Alias for `copy` for nicer inline usage, e.g.
    # client.with_options(timeout=10).foo.create(...)
    with_options = copy

    @override
    def _make_status_error(
        self,
        err_msg: str,
        *,
        body: object,
        response: httpx.Response,
    ) -> APIStatusError:
        # Map each known HTTP status code to its dedicated exception class;
        # anything unmapped falls through to the generic APIStatusError.
        if response.status_code == 400:
            return _exceptions.BadRequestError(err_msg, response=response, body=body)

        if response.status_code == 401:
            return _exceptions.AuthenticationError(err_msg, response=response, body=body)

        if response.status_code == 403:
            return _exceptions.PermissionDeniedError(err_msg, response=response, body=body)

        if response.status_code == 404:
            return _exceptions.NotFoundError(err_msg, response=response, body=body)

        if response.status_code == 409:
            return _exceptions.ConflictError(err_msg, response=response, body=body)

        if response.status_code == 422:
            return _exceptions.UnprocessableEntityError(err_msg, response=response, body=body)

        if response.status_code == 429:
            return _exceptions.RateLimitError(err_msg, response=response, body=body)

        if response.status_code >= 500:
            return _exceptions.InternalServerError(err_msg, response=response, body=body)
        return APIStatusError(err_msg, response=response, body=body)
class TinkerWithRawResponse:
    """View over a `Tinker` client whose resource methods return raw responses
    instead of parsed models. Obtained via `client.with_raw_response`."""

    _client: Tinker
    def __init__(self, client: Tinker) -> None:
        self._client = client
    # Resource accessors mirror those on `Tinker`; imports are deferred to
    # avoid import cycles and keep module import time low.
    @cached_property
    def service(self) -> service.ServiceResourceWithRawResponse:
        from .resources.service import ServiceResourceWithRawResponse
        return ServiceResourceWithRawResponse(self._client.service)
    @cached_property
    def training(self) -> training.TrainingResourceWithRawResponse:
        from .resources.training import TrainingResourceWithRawResponse
        return TrainingResourceWithRawResponse(self._client.training)
    @cached_property
    def models(self) -> models.ModelsResourceWithRawResponse:
        from .resources.models import ModelsResourceWithRawResponse
        return ModelsResourceWithRawResponse(self._client.models)
    @cached_property
    def weights(self) -> weights.WeightsResourceWithRawResponse:
        from .resources.weights import WeightsResourceWithRawResponse
        return WeightsResourceWithRawResponse(self._client.weights)
    @cached_property
    def sampling(self) -> sampling.SamplingResourceWithRawResponse:
        from .resources.sampling import SamplingResourceWithRawResponse
        return SamplingResourceWithRawResponse(self._client.sampling)
    @cached_property
    def futures(self) -> futures.FuturesResourceWithRawResponse:
        from .resources.futures import FuturesResourceWithRawResponse
        return FuturesResourceWithRawResponse(self._client.futures)
    @cached_property
    def telemetry(self) -> telemetry.TelemetryResourceWithRawResponse:
        from .resources.telemetry import TelemetryResourceWithRawResponse
        return TelemetryResourceWithRawResponse(self._client.telemetry)
class AsyncTinkerWithRawResponse:
    """View over an `AsyncTinker` client whose resource methods return raw
    responses instead of parsed models. Obtained via `client.with_raw_response`."""

    _client: AsyncTinker
    def __init__(self, client: AsyncTinker) -> None:
        self._client = client
    # Resource accessors mirror those on `AsyncTinker`; imports are deferred to
    # avoid import cycles and keep module import time low.
    @cached_property
    def service(self) -> service.AsyncServiceResourceWithRawResponse:
        from .resources.service import AsyncServiceResourceWithRawResponse
        return AsyncServiceResourceWithRawResponse(self._client.service)
    @cached_property
    def training(self) -> training.AsyncTrainingResourceWithRawResponse:
        from .resources.training import AsyncTrainingResourceWithRawResponse
        return AsyncTrainingResourceWithRawResponse(self._client.training)
    @cached_property
    def models(self) -> models.AsyncModelsResourceWithRawResponse:
        from .resources.models import AsyncModelsResourceWithRawResponse
        return AsyncModelsResourceWithRawResponse(self._client.models)
    @cached_property
    def weights(self) -> weights.AsyncWeightsResourceWithRawResponse:
        from .resources.weights import AsyncWeightsResourceWithRawResponse
        return AsyncWeightsResourceWithRawResponse(self._client.weights)
    @cached_property
    def sampling(self) -> sampling.AsyncSamplingResourceWithRawResponse:
        from .resources.sampling import AsyncSamplingResourceWithRawResponse
        return AsyncSamplingResourceWithRawResponse(self._client.sampling)
    @cached_property
    def futures(self) -> futures.AsyncFuturesResourceWithRawResponse:
        from .resources.futures import AsyncFuturesResourceWithRawResponse
        return AsyncFuturesResourceWithRawResponse(self._client.futures)
    @cached_property
    def telemetry(self) -> telemetry.AsyncTelemetryResourceWithRawResponse:
        from .resources.telemetry import AsyncTelemetryResourceWithRawResponse
        return AsyncTelemetryResourceWithRawResponse(self._client.telemetry)
class TinkerWithStreamedResponse:
    """View over a `Tinker` client whose resource methods stream response
    content. Obtained via `client.with_streaming_response`."""

    _client: Tinker
    def __init__(self, client: Tinker) -> None:
        self._client = client
    # Resource accessors mirror those on `Tinker`; imports are deferred to
    # avoid import cycles and keep module import time low.
    @cached_property
    def service(self) -> service.ServiceResourceWithStreamingResponse:
        from .resources.service import ServiceResourceWithStreamingResponse
        return ServiceResourceWithStreamingResponse(self._client.service)
    @cached_property
    def training(self) -> training.TrainingResourceWithStreamingResponse:
        from .resources.training import TrainingResourceWithStreamingResponse
        return TrainingResourceWithStreamingResponse(self._client.training)
    @cached_property
    def models(self) -> models.ModelsResourceWithStreamingResponse:
        from .resources.models import ModelsResourceWithStreamingResponse
        return ModelsResourceWithStreamingResponse(self._client.models)
    @cached_property
    def weights(self) -> weights.WeightsResourceWithStreamingResponse:
        from .resources.weights import WeightsResourceWithStreamingResponse
        return WeightsResourceWithStreamingResponse(self._client.weights)
    @cached_property
    def sampling(self) -> sampling.SamplingResourceWithStreamingResponse:
        from .resources.sampling import SamplingResourceWithStreamingResponse
        return SamplingResourceWithStreamingResponse(self._client.sampling)
    @cached_property
    def futures(self) -> futures.FuturesResourceWithStreamingResponse:
        from .resources.futures import FuturesResourceWithStreamingResponse
        return FuturesResourceWithStreamingResponse(self._client.futures)
    @cached_property
    def telemetry(self) -> telemetry.TelemetryResourceWithStreamingResponse:
        from .resources.telemetry import TelemetryResourceWithStreamingResponse
        return TelemetryResourceWithStreamingResponse(self._client.telemetry)
class AsyncTinkerWithStreamedResponse:
    """View over an `AsyncTinker` client whose resource methods stream response
    content. Obtained via `client.with_streaming_response`."""

    _client: AsyncTinker
    def __init__(self, client: AsyncTinker) -> None:
        self._client = client
    # Resource accessors mirror those on `AsyncTinker`; imports are deferred to
    # avoid import cycles and keep module import time low.
    @cached_property
    def service(self) -> service.AsyncServiceResourceWithStreamingResponse:
        from .resources.service import AsyncServiceResourceWithStreamingResponse
        return AsyncServiceResourceWithStreamingResponse(self._client.service)
    @cached_property
    def training(self) -> training.AsyncTrainingResourceWithStreamingResponse:
        from .resources.training import AsyncTrainingResourceWithStreamingResponse
        return AsyncTrainingResourceWithStreamingResponse(self._client.training)
    @cached_property
    def models(self) -> models.AsyncModelsResourceWithStreamingResponse:
        from .resources.models import AsyncModelsResourceWithStreamingResponse
        return AsyncModelsResourceWithStreamingResponse(self._client.models)
    @cached_property
    def weights(self) -> weights.AsyncWeightsResourceWithStreamingResponse:
        from .resources.weights import AsyncWeightsResourceWithStreamingResponse
        return AsyncWeightsResourceWithStreamingResponse(self._client.weights)
    @cached_property
    def sampling(self) -> sampling.AsyncSamplingResourceWithStreamingResponse:
        from .resources.sampling import AsyncSamplingResourceWithStreamingResponse
        return AsyncSamplingResourceWithStreamingResponse(self._client.sampling)
    @cached_property
    def futures(self) -> futures.AsyncFuturesResourceWithStreamingResponse:
        from .resources.futures import AsyncFuturesResourceWithStreamingResponse
        return AsyncFuturesResourceWithStreamingResponse(self._client.futures)
    @cached_property
    def telemetry(self) -> telemetry.AsyncTelemetryResourceWithStreamingResponse:
        from .resources.telemetry import AsyncTelemetryResourceWithStreamingResponse
        return AsyncTelemetryResourceWithStreamingResponse(self._client.telemetry)
# Generic aliases for the canonical client classes, kept for ergonomic imports
# (e.g. `from tinker import Client`).
Client = Tinker
AsyncClient = AsyncTinker

219
src/tinker/_compat.py Normal file
View file

@ -0,0 +1,219 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload
from datetime import date, datetime
from typing_extensions import Self, Literal
import pydantic
from pydantic.fields import FieldInfo
from ._types import IncEx, StrBytesIntFloat
_T = TypeVar("_T")
_ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel)
# --------------- Pydantic v2 compatibility ---------------
# Pyright incorrectly reports some of our functions as overriding a method when they don't
# pyright: reportIncompatibleMethodOverride=false
PYDANTIC_V2 = pydantic.VERSION.startswith("2.")
# v1 re-exports
if TYPE_CHECKING:
def parse_date(value: date | StrBytesIntFloat) -> date: # noqa: ARG001
...
def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: # noqa: ARG001
...
def get_args(t: type[Any]) -> tuple[Any, ...]: # noqa: ARG001
...
def is_union(tp: type[Any] | None) -> bool: # noqa: ARG001
...
def get_origin(t: type[Any]) -> type[Any] | None: # noqa: ARG001
...
def is_literal_type(type_: type[Any]) -> bool: # noqa: ARG001
...
def is_typeddict(type_: type[Any]) -> bool: # noqa: ARG001
...
else:
if PYDANTIC_V2:
from pydantic.v1.typing import (
get_args as get_args,
is_union as is_union,
get_origin as get_origin,
is_typeddict as is_typeddict,
is_literal_type as is_literal_type,
)
from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
else:
from pydantic.typing import (
get_args as get_args,
is_union as is_union,
get_origin as get_origin,
is_typeddict as is_typeddict,
is_literal_type as is_literal_type,
)
from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
# refactored config
if TYPE_CHECKING:
from pydantic import ConfigDict as ConfigDict
else:
if PYDANTIC_V2:
from pydantic import ConfigDict
else:
# TODO: provide an error message here?
ConfigDict = None
# renamed methods / properties
def parse_obj(model: type[_ModelT], value: object) -> _ModelT:
    """Validate *value* into an instance of *model*, bridging pydantic v1/v2 naming."""
    if not PYDANTIC_V2:
        # v1 spelling; deprecated under v2, hence the pyright suppression.
        return cast(_ModelT, model.parse_obj(value))  # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
    return model.model_validate(value)
def field_is_required(field: FieldInfo) -> bool:
    """Report whether *field* has no usable default, i.e. the caller must supply it."""
    if not PYDANTIC_V2:
        # v1 exposes a plain attribute rather than a method.
        return field.required  # type: ignore
    return field.is_required()
def field_get_default(field: FieldInfo) -> Any:
    """Return the field's default value, mapping pydantic-v2's "undefined" sentinel to None."""
    value = field.get_default()
    if PYDANTIC_V2:
        from pydantic_core import PydanticUndefined

        # Compare by identity, not equality: `PydanticUndefined` is a sentinel, and
        # `==` would delegate to the default value's own __eq__, which may not
        # return a bool (e.g. numpy arrays) and would break the truth test here.
        if value is PydanticUndefined:
            return None
        return value
    return value
def field_outer_type(field: FieldInfo) -> Any:
    """Return the field's declared (outer) type annotation."""
    if not PYDANTIC_V2:
        # v1 keeps the full outer type on a dedicated attribute.
        return field.outer_type_  # type: ignore
    return field.annotation
def get_model_config(model: type[pydantic.BaseModel]) -> Any:
    """Return the model's configuration (v2 `model_config` / v1 `__config__`)."""
    return model.model_config if PYDANTIC_V2 else model.__config__  # type: ignore
def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
    """Return the mapping of field name -> FieldInfo for *model*."""
    return model.model_fields if PYDANTIC_V2 else model.__fields__  # type: ignore
def model_copy(model: _ModelT, *, deep: bool = False) -> _ModelT:
    """Return a copy of *model*; set ``deep=True`` to also copy nested objects."""
    if not PYDANTIC_V2:
        return model.copy(deep=deep)  # type: ignore
    return model.model_copy(deep=deep)
def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
    """Serialize *model* to a JSON string, optionally pretty-printed."""
    if not PYDANTIC_V2:
        return model.json(indent=indent)  # type: ignore
    return model.model_dump_json(indent=indent)
def model_dump(
    model: pydantic.BaseModel,
    *,
    exclude: IncEx | None = None,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
    warnings: bool = True,
    mode: Literal["json", "python"] = "python",
) -> dict[str, Any]:
    """Dump *model* to a plain dict, supporting both pydantic v1 and v2 models."""
    if not (PYDANTIC_V2 or hasattr(model, "model_dump")):
        # pydantic v1 fallback: `.dict()` has no `mode` or `warnings` arguments.
        return cast(
            "dict[str, Any]",
            model.dict(  # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
                exclude=exclude,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
            ),
        )
    return model.model_dump(
        mode=mode,
        exclude=exclude,
        exclude_unset=exclude_unset,
        exclude_defaults=exclude_defaults,
        # warnings are not supported in Pydantic v1
        warnings=warnings if PYDANTIC_V2 else True,
    )
def model_parse(model: type[_ModelT], data: Any) -> _ModelT:
    """Parse *data* into an instance of *model* (v2 `model_validate` / v1 `parse_obj`)."""
    if not PYDANTIC_V2:
        return model.parse_obj(data)  # pyright: ignore[reportDeprecated]
    return model.model_validate(data)
# generic models
#
# Provide a single `GenericModel` base that works under both pydantic major
# versions (v1 requires a dedicated generics base class).
if TYPE_CHECKING:
    class GenericModel(pydantic.BaseModel): ...
else:
    if PYDANTIC_V2:
        # there no longer needs to be a distinction in v2 but
        # we still have to create our own subclass to avoid
        # inconsistent MRO ordering errors
        class GenericModel(pydantic.BaseModel): ...
    else:
        import pydantic.generics
        class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ...
# cached properties
#
# At runtime both names resolve to functools.cached_property; the
# TYPE_CHECKING branch only improves editor/typing behavior.
if TYPE_CHECKING:
    cached_property = property
    # we define a separate type (copied from typeshed)
    # that represents that `cached_property` is `set`able
    # at runtime, which differs from `@property`.
    #
    # this is a separate type as editors likely special case
    # `@property` and we don't want to cause issues just to have
    # more helpful internal types.
    class typed_cached_property(Generic[_T]):
        func: Callable[[Any], _T]
        attrname: str | None
        def __init__(self, func: Callable[[Any], _T]) -> None: ...
        @overload
        def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: ...
        @overload
        def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: ...
        def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self:
            raise NotImplementedError()
        def __set_name__(self, owner: type[Any], name: str) -> None: ...
        # __set__ is not defined at runtime, but @cached_property is designed to be settable
        def __set__(self, instance: object, value: _T) -> None: ...
else:
    from functools import cached_property as cached_property
    typed_cached_property = cached_property

14
src/tinker/_constants.py Normal file
View file

@ -0,0 +1,14 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
import httpx
# Internal header names used to pass response-handling options through the
# request pipeline; they are stripped before the request is sent.
RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response"
OVERRIDE_CAST_TO_HEADER = "____stainless_override_cast_to"
# default timeout is 1 minute
DEFAULT_TIMEOUT = httpx.Timeout(timeout=60, connect=5.0)
DEFAULT_MAX_RETRIES = 10
DEFAULT_CONNECTION_LIMITS = httpx.Limits(max_connections=1000, max_keepalive_connections=20)
# Exponential backoff bounds (seconds) for retried requests.
INITIAL_RETRY_DELAY = 0.5
MAX_RETRY_DELAY = 10.0

108
src/tinker/_exceptions.py Normal file
View file

@ -0,0 +1,108 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing_extensions import Literal
import httpx
__all__ = [
"BadRequestError",
"AuthenticationError",
"PermissionDeniedError",
"NotFoundError",
"ConflictError",
"UnprocessableEntityError",
"RateLimitError",
"InternalServerError",
]
class TinkerError(Exception):
    """Base class for every exception raised by this SDK."""

    pass
class APIError(TinkerError):
    """Base class for errors tied to a specific HTTP request."""

    message: str
    request: httpx.Request
    body: object | None
    """The API response body.
    If the API responded with a valid JSON structure then this property will be the
    decoded result.
    If it isn't a valid JSON structure then this will be the raw response.
    If there was no response associated with this error then it will be `None`.
    """
    def __init__(self, message: str, request: httpx.Request, *, body: object | None) -> None:  # noqa: ARG002
        super().__init__(message)
        self.request = request
        self.message = message
        self.body = body
class APIResponseValidationError(APIError):
    """Raised when the response body does not match the expected schema
    (only when strict response validation is enabled on the client)."""

    response: httpx.Response
    status_code: int
    def __init__(self, response: httpx.Response, body: object | None, *, message: str | None = None) -> None:
        super().__init__(message or "Data returned by API invalid for expected schema.", response.request, body=body)
        self.response = response
        self.status_code = response.status_code
class APIStatusError(APIError):
    """Raised when an API response has a status code of 4xx or 5xx."""

    response: httpx.Response
    status_code: int
    def __init__(self, message: str, *, response: httpx.Response, body: object | None) -> None:
        super().__init__(message, response.request, body=body)
        self.response = response
        self.status_code = response.status_code
class APIConnectionError(APIError):
    """Raised when the request could not be completed (no response received)."""

    def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None:
        super().__init__(message, request, body=None)
class APITimeoutError(APIConnectionError):
    """Raised when the request exceeded the configured timeout."""

    def __init__(self, request: httpx.Request) -> None:
        super().__init__(message="Request timed out.", request=request)
# One subclass per HTTP status code; the client's `_make_status_error` maps
# response codes onto these so callers can catch specific failures.
class BadRequestError(APIStatusError):
    status_code: Literal[400] = 400  # pyright: ignore[reportIncompatibleVariableOverride]
class AuthenticationError(APIStatusError):
    status_code: Literal[401] = 401  # pyright: ignore[reportIncompatibleVariableOverride]
class PermissionDeniedError(APIStatusError):
    status_code: Literal[403] = 403  # pyright: ignore[reportIncompatibleVariableOverride]
class NotFoundError(APIStatusError):
    status_code: Literal[404] = 404  # pyright: ignore[reportIncompatibleVariableOverride]
class ConflictError(APIStatusError):
    status_code: Literal[409] = 409  # pyright: ignore[reportIncompatibleVariableOverride]
class UnprocessableEntityError(APIStatusError):
    status_code: Literal[422] = 422  # pyright: ignore[reportIncompatibleVariableOverride]
class RateLimitError(APIStatusError):
    status_code: Literal[429] = 429  # pyright: ignore[reportIncompatibleVariableOverride]
class InternalServerError(APIStatusError):
    # Covers every 5xx status, so no single status-code literal is pinned here.
    pass

123
src/tinker/_files.py Normal file
View file

@ -0,0 +1,123 @@
from __future__ import annotations
import io
import os
import pathlib
from typing import overload
from typing_extensions import TypeGuard
import anyio
from ._types import (
FileTypes,
FileContent,
RequestFiles,
HttpxFileTypes,
Base64FileInput,
HttpxFileContent,
HttpxRequestFiles,
)
from ._utils import is_tuple_t, is_mapping_t, is_sequence_t
def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:
    """Narrow *obj* to inputs we can base64-encode: readable streams or filesystem paths."""
    return isinstance(obj, (io.IOBase, os.PathLike))
def is_file_content(obj: object) -> TypeGuard[FileContent]:
    """Narrow *obj* to the raw content types we can upload directly."""
    return isinstance(obj, (bytes, tuple, io.IOBase, os.PathLike))
def assert_is_file_content(obj: object, *, key: str | None = None) -> None:
    """Raise ``RuntimeError`` if *obj* is not an acceptable file-content value."""
    if is_file_content(obj):
        return

    # Include the mapping key in the message when we know it, to point the
    # user at the offending entry.
    prefix = f"Expected file input `{obj!r}`" if key is None else f"Expected entry at `{key}`"
    raise RuntimeError(
        f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead."
    ) from None
@overload
def to_httpx_files(files: None) -> None: ...
@overload
def to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: ...
def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
    """Convert our `RequestFiles` input (mapping or sequence of pairs) into the
    structure httpx expects, normalizing each file entry along the way."""
    if files is None:
        return None

    if is_mapping_t(files):
        return {name: _transform_file(content) for name, content in files.items()}
    if is_sequence_t(files):
        return [(name, _transform_file(content)) for name, content in files]
    raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence")
def _transform_file(file: FileTypes) -> HttpxFileTypes:
    """Normalize a single file entry: filesystem paths are materialized into
    ``(name, bytes)`` pairs; tuples get their content element read; other raw
    content passes through unchanged."""
    if is_file_content(file):
        if isinstance(file, os.PathLike):
            path = pathlib.Path(file)
            return (path.name, path.read_bytes())
        return file

    if is_tuple_t(file):
        return (file[0], read_file_content(file[1]), *file[2:])

    # Plain string literal: there are no placeholders, so the stray `f`
    # prefix was removed (ruff F541).
    raise TypeError("Expected file types input to be a FileContent type or to be a tuple")
def read_file_content(file: FileContent) -> HttpxFileContent:
    """Materialize *file* into bytes when it is a filesystem path; pass any
    other content (bytes, streams) through untouched."""
    if not isinstance(file, os.PathLike):
        return file
    return pathlib.Path(file).read_bytes()
@overload
async def async_to_httpx_files(files: None) -> None: ...
@overload
async def async_to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: ...
async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
    """Async variant of `to_httpx_files`; file paths are read without blocking
    the event loop."""
    if files is None:
        return None

    if is_mapping_t(files):
        return {key: await _async_transform_file(file) for key, file in files.items()}
    if is_sequence_t(files):
        return [(key, await _async_transform_file(file)) for key, file in files]
    # BUG FIX: this message previously lacked the `f` prefix, so the literal
    # text "{type(files)}" was raised instead of the offending type.
    raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence")
async def _async_transform_file(file: FileTypes) -> HttpxFileTypes:
    """Async counterpart of `_transform_file`; paths are read via ``anyio`` so
    the event loop is not blocked."""
    if is_file_content(file):
        if isinstance(file, os.PathLike):
            path = anyio.Path(file)
            return (path.name, await path.read_bytes())
        return file

    if is_tuple_t(file):
        return (file[0], await async_read_file_content(file[1]), *file[2:])

    # Plain string literal: there are no placeholders, so the stray `f`
    # prefix was removed (ruff F541).
    raise TypeError("Expected file types input to be a FileContent type or to be a tuple")
async def async_read_file_content(file: FileContent) -> HttpxFileContent:
    """Asynchronously materialize *file* into bytes when it is a filesystem
    path; pass any other content through untouched."""
    if not isinstance(file, os.PathLike):
        return file
    return await anyio.Path(file).read_bytes()

560
src/tinker/_models.py Normal file
View file

@ -0,0 +1,560 @@
from __future__ import annotations
import inspect
from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, Optional, cast
from datetime import date, datetime
from typing import TYPE_CHECKING, Any, Callable, Generic, Type, TypeVar, Union, cast
import pydantic
from pydantic.fields import FieldInfo
from typing_extensions import (
List,
Unpack,
Literal,
ClassVar,
ParamSpec,
Protocol,
Required,
TypedDict,
TypeGuard,
Unpack,
final,
runtime_checkable,
)
from ._compat import (
PYDANTIC_V2,
ConfigDict,
field_get_default,
get_args,
get_origin,
is_literal_type,
is_union,
parse_obj,
)
from ._compat import (
GenericModel as BaseGenericModel,
)
from ._constants import RAW_RESPONSE_HEADER
from ._types import (
AnyMapping,
Body,
Headers,
HttpxRequestFiles,
NotGiven,
Query,
Timeout,
)
from ._utils import (
PropertyInfo,
extract_type_arg,
is_annotated_type,
is_given,
is_list,
is_mapping,
is_type_alias_type,
lru_cache,
parse_date,
parse_datetime,
strip_annotated_type,
strip_not_given,
)
if TYPE_CHECKING:
from pydantic_core.core_schema import LiteralSchema, ModelField, ModelFieldsSchema, ModelSchema
__all__ = ["StrictBase", "BaseModel", "GenericModel"]
_T = TypeVar("_T")
_BaseModelT = TypeVar("_BaseModelT", bound=pydantic.BaseModel)
P = ParamSpec("P")
@runtime_checkable
class _ConfigProtocol(Protocol):
    # Structural type matching a pydantic-v1 style model `Config` class.
    allow_population_by_field_name: bool
class StrictBase(pydantic.BaseModel):
    """
    Don't allow extra fields, so user errors are caught earlier.
    Use this for request types.
    """
    # frozen: instances are immutable (and therefore hashable);
    # extra="forbid": unknown fields raise a validation error.
    model_config = ConfigDict(frozen=True, extra="forbid")
    def __str__(self) -> str:
        # Make str() match repr() for unambiguous logging output.
        return repr(self)
class BaseModel(pydantic.BaseModel):
    """
    Use for classes that may appear in responses. Allow extra fields, so old clients can still work.
    """
    # For future-proofing, we ignore extra fields in case the server adds new fields.
    model_config = ConfigDict(frozen=True, extra="ignore")
    def __str__(self) -> str:
        # Make str() match repr() for unambiguous logging output.
        return repr(self)
def _construct_field(value: object, field: FieldInfo, key: str) -> object:
    """Construct a single model field value, falling back to the field default for `None`."""
    if value is None:
        return field_get_default(field)

    # the field's declared type lives in a different attribute on pydantic v1 vs v2
    if PYDANTIC_V2:
        type_ = field.annotation
    else:
        type_ = cast(type, field.outer_type_)  # type: ignore

    if type_ is None:
        raise RuntimeError(f"Unexpected field type is None for {key}")

    return construct_type(value=value, type_=type_, metadata=getattr(field, "metadata", None))
def is_basemodel(type_: type) -> bool:
    """Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`"""
    if is_union(type_):
        return any(is_basemodel(variant) for variant in get_args(type_))

    return is_basemodel_type(type_)
def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]:
    """Narrow `type_` to our model classes, handling subscripted generics via their origin."""
    origin = get_origin(type_) or type_
    return inspect.isclass(origin) and issubclass(origin, (BaseModel, GenericModel))
def build(
    base_model_cls: Callable[P, _BaseModelT],
    *args: P.args,
    **kwargs: P.kwargs,
) -> _BaseModelT:
    """Construct a BaseModel class without validation.

    This is useful for cases where you need to instantiate a `BaseModel`
    from an API response as this provides type-safe params which isn't supported
    by helpers like `construct_type()`.

    ```py
    build(MyModel, my_field_a="foo", my_field_b=123)
    ```

    Raises:
        TypeError: if any positional arguments are given.
    """
    if args:
        # positional args cannot be mapped onto model fields reliably
        raise TypeError(
            "Received positional arguments which are not supported; Keyword arguments must be used instead",
        )

    return cast(_BaseModelT, construct_type(type_=base_model_cls, value=kwargs))
def construct_type_unchecked(*, value: object, type_: type[_T]) -> _T:
    """Loose coercion to the expected type with construction of nested values.

    Note: the returned value from this function is not guaranteed to match the
    given type — it only narrows the static type via `cast`.
    """
    return cast(_T, construct_type(value=value, type_=type_))
def construct_type(*, value: object, type_: object, metadata: Optional[List[Any]] = None) -> object:
    """Loose coercion to the expected type with construction of nested values.

    If the given value does not match the expected type then it is returned as-is.
    """
    # store a reference to the original type we were given before we extract any inner
    # types so that we can properly resolve forward references in `TypeAliasType` annotations
    original_type = None

    # we allow `object` as the input type because otherwise, passing things like
    # `Literal['value']` will be reported as a type error by type checkers
    type_ = cast("type[object]", type_)
    if is_type_alias_type(type_):
        original_type = type_  # type: ignore[unreachable]
        type_ = type_.__value__  # type: ignore[unreachable]

    # unwrap `Annotated[T, ...]` -> `T`; explicit `metadata` takes precedence over
    # metadata embedded in the annotation itself
    if metadata is not None and len(metadata) > 0:
        meta: tuple[Any, ...] = tuple(metadata)
    elif is_annotated_type(type_):
        meta = get_args(type_)[1:]
        type_ = extract_type_arg(type_, 0)
    else:
        meta = tuple()

    # we need to use the origin class for any types that are subscripted generics
    # e.g. Dict[str, object]
    origin = get_origin(type_) or type_
    args = get_args(type_)

    if is_union(origin):
        # first try strict validation against the whole union
        try:
            return validate_type(type_=cast("type[object]", original_type or type_), value=value)
        except Exception:
            pass

        # if the type is a discriminated union then we want to construct the right variant
        # in the union, even if the data doesn't match exactly, otherwise we'd break code
        # that relies on the constructed class types, e.g.
        #
        # class FooType:
        #   kind: Literal['foo']
        #   value: str
        #
        # class BarType:
        #   kind: Literal['bar']
        #   value: int
        #
        # without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
        # we'd end up constructing `FooType` when it should be `BarType`.
        discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
        if discriminator and is_mapping(value):
            variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
            if variant_value and isinstance(variant_value, str):
                variant_type = discriminator.mapping.get(variant_value)
                if variant_type:
                    return construct_type(type_=variant_type, value=value)

        # if the data is not valid, use the first variant that doesn't fail while deserializing
        for variant in args:
            try:
                return construct_type(value=value, type_=variant)
            except Exception:
                continue

        raise RuntimeError(f"Could not convert data into a valid instance of {type_}")
    if origin == dict:
        if not is_mapping(value):
            return value

        _, items_type = get_args(type_)  # Dict[_, items_type]
        return {key: construct_type(value=item, type_=items_type) for key, item in value.items()}

    if (
        not is_literal_type(type_)
        and inspect.isclass(origin)
        and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel))
    ):
        if is_list(value):
            # a list of mappings targeting a model type: construct each entry
            return [
                cast(Any, type_).construct(**entry) if is_mapping(entry) else entry
                for entry in value
            ]

        if is_mapping(value):
            if issubclass(type_, BaseModel):
                return type_.construct(**value)  # type: ignore[arg-type]

            return cast(Any, type_).construct(**value)

    if origin == list:
        if not is_list(value):
            return value

        inner_type = args[0]  # List[inner_type]
        return [construct_type(value=entry, type_=inner_type) for entry in value]

    if origin == float:
        if isinstance(value, int):
            # only coerce int -> float when the conversion is lossless
            coerced = float(value)
            if coerced != value:
                return value
            return coerced

        return value

    if type_ == datetime:
        try:
            return parse_datetime(value)  # type: ignore
        except Exception:
            return value

    if type_ == date:
        try:
            return parse_date(value)  # type: ignore
        except Exception:
            return value

    return value
@runtime_checkable
class CachedDiscriminatorType(Protocol):
    # Structural marker for union types that have discriminator details cached on them
    # by `_build_discriminated_union_meta`.
    __discriminator__: DiscriminatorDetails
class DiscriminatorDetails:
    """Describes how to pick the concrete variant of a discriminated union.

    ``field_name`` is the discriminator attribute on the variant class, e.g. for
    ``type: Literal['foo']`` it is ``'type'``.

    ``field_alias_from`` is the wire name of that attribute when an alias is
    declared, e.g. ``Field(alias='type_from_api')`` yields ``'type_from_api'``;
    otherwise it is ``None``.

    ``mapping`` maps each discriminator value to its variant type, e.g.
    ``{'foo': FooVariant, 'bar': BarVariant}``.
    """

    field_name: str
    field_alias_from: str | None
    mapping: dict[str, type]

    def __init__(
        self,
        *,
        mapping: dict[str, type],
        discriminator_field: str,
        discriminator_alias: str | None,
    ) -> None:
        self.field_name = discriminator_field
        self.field_alias_from = discriminator_alias
        self.mapping = mapping
def _build_discriminated_union_meta(
    *, union: type, meta_annotations: tuple[Any, ...]
) -> DiscriminatorDetails | None:
    """Resolve (and cache on the union object itself) discriminator metadata for a union.

    Returns `None` when no `PropertyInfo(discriminator=...)` annotation is present or
    when no variant exposes a string-literal discriminator field.
    """
    if isinstance(union, CachedDiscriminatorType):
        # already computed for this union object; reuse the cached details
        return union.__discriminator__

    discriminator_field_name: str | None = None

    for annotation in meta_annotations:
        if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None:
            discriminator_field_name = annotation.discriminator
            break

    if not discriminator_field_name:
        return None

    mapping: dict[str, type] = {}
    discriminator_alias: str | None = None

    for variant in get_args(union):
        variant = strip_annotated_type(variant)
        if is_basemodel_type(variant):
            if PYDANTIC_V2:
                # pydantic v2: walk the core schema to find the field's literal values
                field = _extract_field_schema_pv2(variant, discriminator_field_name)
                if not field:
                    continue

                # Note: if one variant defines an alias then they all should
                discriminator_alias = field.get("serialization_alias")

                field_schema = field["schema"]

                if field_schema["type"] == "literal":
                    for entry in cast("LiteralSchema", field_schema)["expected"]:
                        if isinstance(entry, str):
                            mapping[entry] = variant
            else:
                # pydantic v1: read the literal values off the deprecated `__fields__` info
                field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(
                    discriminator_field_name
                )  # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
                if not field_info:
                    continue

                # Note: if one variant defines an alias then they all should
                discriminator_alias = field_info.alias

                if (annotation := getattr(field_info, "annotation", None)) and is_literal_type(
                    annotation
                ):
                    for entry in get_args(annotation):
                        if isinstance(entry, str):
                            mapping[entry] = variant

    if not mapping:
        return None

    details = DiscriminatorDetails(
        mapping=mapping,
        discriminator_field=discriminator_field_name,
        discriminator_alias=discriminator_alias,
    )
    # cache the computed details on the union object for subsequent lookups
    cast(CachedDiscriminatorType, union).__discriminator__ = details
    return details
def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
    """Pull the pydantic-v2 core-schema entry for `field_name`, unwrapping definitions.

    Returns `None` when the model's schema isn't in the expected shape or the
    field isn't present.
    """
    schema = model.__pydantic_core_schema__
    if schema["type"] == "definitions":
        # unwrap the outer `definitions` wrapper to reach the model schema
        schema = schema["schema"]

    if schema["type"] != "model":
        return None

    schema = cast("ModelSchema", schema)
    fields_schema = schema["schema"]
    if fields_schema["type"] != "model-fields":
        return None

    fields_schema = cast("ModelFieldsSchema", fields_schema)
    field = fields_schema["fields"].get(field_name)
    if not field:
        return None

    return cast("ModelField", field)  # pyright: ignore[reportUnnecessaryCast]
def validate_type(*, type_: type[_T], value: object) -> _T:
    """Strict validation that the given value matches the expected type.

    Pydantic models go through `parse_obj`; all other types are validated via the
    version-specific `_validate_non_model_type` helper.
    """
    if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
        return cast(_T, parse_obj(type_, value))

    return cast(_T, _validate_non_model_type(type_=type_, value=value))
def set_pydantic_config(typ: Any, config: pydantic.ConfigDict) -> None:
    """Add a pydantic config for the given type.

    Note: this is a no-op on Pydantic v1.
    """
    setattr(typ, "__pydantic_config__", config)  # noqa: B010
# our use of subclassing here causes weirdness for type checkers,
# so we just pretend that we don't subclass
if TYPE_CHECKING:
    GenericModel = BaseModel
else:

    class GenericModel(BaseGenericModel, BaseModel):
        # Runtime generic-model base; statically aliased to `BaseModel` above.
        pass
if PYDANTIC_V2:
    from pydantic import TypeAdapter as _TypeAdapter

    # memoise TypeAdapter construction per type — building an adapter is not free
    _CachedTypeAdapter = cast("TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter))

    if TYPE_CHECKING:
        from pydantic import TypeAdapter
    else:
        TypeAdapter = _CachedTypeAdapter

    def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
        # v2 path: delegate validation to a (cached) TypeAdapter
        return TypeAdapter(type_).validate_python(value)

elif not TYPE_CHECKING:  # TODO: condition is weird

    class RootModel(GenericModel, Generic[_T]):
        """Used as a placeholder to easily convert runtime types to a Pydantic format
        to provide validation.

        For example:
        ```py
        validated = RootModel[int](__root__="5").__root__
        # validated: 5
        ```
        """

        __root__: _T

    def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
        # v1 path: wrap the type in a RootModel so pydantic can validate it
        model = _create_pydantic_model(type_).validate(value)
        return cast(_T, model.__root__)

    def _create_pydantic_model(type_: _T) -> Type[RootModel[_T]]:
        return RootModel[type_]  # type: ignore
class FinalRequestOptionsInput(TypedDict, total=False):
    """Keyword arguments accepted by `FinalRequestOptions.construct()`."""

    method: Required[str]
    url: Required[str]
    params: Query
    headers: Headers
    max_retries: int
    timeout: float | Timeout | None
    files: HttpxRequestFiles | None
    idempotency_key: str
    json_data: Body
    extra_json: AnyMapping
    follow_redirects: bool
@final
class FinalRequestOptions(pydantic.BaseModel):
    """Fully-resolved options for a single HTTP request.

    `NotGiven` defaults distinguish "option not passed" from an explicit `None`.
    """

    method: str
    url: str
    params: Query = {}
    headers: Union[Headers, NotGiven] = NotGiven()
    max_retries: Union[int, NotGiven] = NotGiven()
    timeout: Union[float, Timeout, None, NotGiven] = NotGiven()
    files: Union[HttpxRequestFiles, None] = None
    idempotency_key: Union[str, None] = None
    post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven()
    follow_redirects: Union[bool, None] = None

    # It should be noted that we cannot use `json` here as that would override
    # a BaseModel method in an incompatible fashion.
    json_data: Union[Body, None] = None
    extra_json: Union[AnyMapping, None] = None

    if PYDANTIC_V2:
        model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
    else:

        class Config(pydantic.BaseConfig):  # pyright: ignore[reportDeprecated]
            arbitrary_types_allowed: bool = True

    def get_max_retries(self, max_retries: int) -> int:
        """Return the per-request retry override if one was given, else `max_retries`."""
        if isinstance(self.max_retries, NotGiven):
            return max_retries
        return self.max_retries

    def _strip_raw_response_header(self) -> None:
        # Remove the internal raw-response marker header, copying the mapping
        # first so the original headers object is left untouched.
        if not is_given(self.headers):
            return

        if self.headers.get(RAW_RESPONSE_HEADER):
            self.headers = {**self.headers}
            self.headers.pop(RAW_RESPONSE_HEADER)

    # override the `construct` method so that we can run custom transformations.
    # this is necessary as we don't want to do any actual runtime type checking
    # (which means we can't use validators) but we do want to ensure that `NotGiven`
    # values are not present
    #
    # type ignore required because we're adding explicit types to `**values`
    @classmethod
    def construct(  # type: ignore
        cls,
        _fields_set: set[str] | None = None,
        **values: Unpack[FinalRequestOptionsInput],
    ) -> FinalRequestOptions:
        kwargs: dict[str, Any] = {
            # we unconditionally call `strip_not_given` on any value
            # as it will just ignore any non-mapping types
            key: strip_not_given(value)
            for key, value in values.items()
        }
        if PYDANTIC_V2:
            return super().model_construct(_fields_set, **kwargs)
        return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs))  # pyright: ignore[reportDeprecated]

    if not TYPE_CHECKING:
        # type checkers incorrectly complain about this assignment
        model_construct = construct

150
src/tinker/_qs.py Normal file
View file

@ -0,0 +1,150 @@
from __future__ import annotations
from typing import Any, List, Tuple, Union, Mapping, TypeVar
from urllib.parse import parse_qs, urlencode
from typing_extensions import Literal, get_args
from ._types import NOT_GIVEN, NotGiven, NotGivenOr
from ._utils import flatten
_T = TypeVar("_T")
ArrayFormat = Literal["comma", "repeat", "indices", "brackets"]
NestedFormat = Literal["dots", "brackets"]
PrimitiveData = Union[str, int, float, bool, None]
# this should be Data = Union[PrimitiveData, "List[Data]", "Tuple[Data]", "Mapping[str, Data]"]
# https://github.com/microsoft/pyright/issues/3555
Data = Union[PrimitiveData, List[Any], Tuple[Any], "Mapping[str, Any]"]
Params = Mapping[str, Data]
class Querystring:
    """Serialises (possibly nested) parameter mappings into URL query strings."""

    array_format: ArrayFormat
    nested_format: NestedFormat

    def __init__(
        self,
        *,
        array_format: ArrayFormat = "repeat",
        nested_format: NestedFormat = "brackets",
    ) -> None:
        self.array_format = array_format
        self.nested_format = nested_format

    def parse(self, query: str) -> Mapping[str, object]:
        # Note: custom format syntax is not supported yet
        return parse_qs(query)

    def stringify(
        self,
        params: Params,
        *,
        array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN,
        nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN,
    ) -> str:
        """Encode `params` into a URL query string, e.g. ``a=1&b=2``."""
        pairs = self.stringify_items(
            params,
            array_format=array_format,
            nested_format=nested_format,
        )
        return urlencode(pairs)

    def stringify_items(
        self,
        params: Params,
        *,
        array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN,
        nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN,
    ) -> list[tuple[str, str]]:
        """Return the flat list of ``(key, value)`` pairs for `params`."""
        opts = Options(
            qs=self,
            array_format=array_format,
            nested_format=nested_format,
        )
        return flatten([self._stringify_item(key, value, opts) for key, value in params.items()])

    def _stringify_item(
        self,
        key: str,
        value: Data,
        opts: Options,
    ) -> list[tuple[str, str]]:
        if isinstance(value, Mapping):
            # nested mapping: recurse with a derived key per entry
            pairs: list[tuple[str, str]] = []
            for child_key, child_value in value.items():
                # TODO: error if unknown format
                derived_key = (
                    f"{key}.{child_key}" if opts.nested_format == "dots" else f"{key}[{child_key}]"
                )
                pairs.extend(self._stringify_item(derived_key, child_value, opts))
            return pairs

        if isinstance(value, (list, tuple)):
            array_format = opts.array_format
            if array_format == "comma":
                joined = ",".join(
                    self._primitive_value_to_str(item) for item in value if item is not None
                )
                return [(key, joined)]

            if array_format == "repeat":
                pairs = []
                for item in value:
                    pairs.extend(self._stringify_item(key, item, opts))
                return pairs

            if array_format == "indices":
                raise NotImplementedError("The array indices format is not supported yet")

            if array_format == "brackets":
                pairs = []
                bracketed_key = key + "[]"
                for item in value:
                    pairs.extend(self._stringify_item(bracketed_key, item, opts))
                return pairs

            raise NotImplementedError(
                f"Unknown array_format value: {array_format}, choose from {', '.join(get_args(ArrayFormat))}"
            )

        serialised = self._primitive_value_to_str(value)
        if not serialised:
            # `None` serialises to "" and is omitted entirely (as are empty strings)
            return []
        return [(key, serialised)]

    def _primitive_value_to_str(self, value: PrimitiveData) -> str:
        # copied from httpx
        if value is True:
            return "true"
        if value is False:
            return "false"
        if value is None:
            return ""
        return str(value)
# Module-level default instance plus convenience aliases bound to it.
_qs = Querystring()
parse = _qs.parse
stringify = _qs.stringify
stringify_items = _qs.stringify_items
class Options:
    """Formatting options resolved for a single stringify call.

    Anything not explicitly given falls back to the defaults configured on the
    `Querystring` instance.
    """

    array_format: ArrayFormat
    nested_format: NestedFormat

    def __init__(
        self,
        qs: Querystring = _qs,
        *,
        array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN,
        nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN,
    ) -> None:
        if isinstance(array_format, NotGiven):
            self.array_format = qs.array_format
        else:
            self.array_format = array_format

        if isinstance(nested_format, NotGiven):
            self.nested_format = qs.nested_format
        else:
            self.nested_format = nested_format

43
src/tinker/_resource.py Normal file
View file

@ -0,0 +1,43 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
import time
from typing import TYPE_CHECKING
import anyio
if TYPE_CHECKING:
from ._client import Tinker, AsyncTinker
class SyncAPIResource:
    """Base class for synchronous API resources.

    Holds a reference to the shared `Tinker` client and binds its HTTP verb
    helpers onto the resource for convenient access.
    """

    _client: Tinker

    def __init__(self, client: Tinker) -> None:
        self._client = client
        # bind the client's request helpers directly so subclasses can call
        # e.g. `self._get(...)` instead of `self._client.get(...)`
        self._get = client.get
        self._post = client.post
        self._patch = client.patch
        self._put = client.put
        self._delete = client.delete
        self._get_api_list = client.get_api_list

    def _sleep(self, seconds: float) -> None:
        # blocking sleep; the async counterpart uses `anyio.sleep`
        time.sleep(seconds)
class AsyncAPIResource:
    """Base class for asynchronous API resources.

    Holds a reference to the shared `AsyncTinker` client and binds its HTTP verb
    helpers onto the resource for convenient access.
    """

    _client: AsyncTinker

    def __init__(self, client: AsyncTinker) -> None:
        self._client = client
        # bind the client's request helpers directly so subclasses can call
        # e.g. `await self._get(...)` instead of `await self._client.get(...)`
        self._get = client.get
        self._post = client.post
        self._patch = client.patch
        self._put = client.put
        self._delete = client.delete
        self._get_api_list = client.get_api_list

    async def _sleep(self, seconds: float) -> None:
        # non-blocking sleep compatible with any anyio-backed event loop
        await anyio.sleep(seconds)

830
src/tinker/_response.py Normal file
View file

@ -0,0 +1,830 @@
from __future__ import annotations
import os
import inspect
import logging
import datetime
import functools
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Union,
Generic,
TypeVar,
Callable,
Iterator,
AsyncIterator,
cast,
overload,
)
from typing_extensions import Awaitable, ParamSpec, override, get_origin
import anyio
import httpx
import pydantic
from ._types import NoneType
from ._utils import is_given, extract_type_arg, is_annotated_type, is_type_alias_type, extract_type_var_from_base
from ._models import BaseModel, is_basemodel
from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER
from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
from ._exceptions import TinkerError, APIResponseValidationError
if TYPE_CHECKING:
from ._models import FinalRequestOptions
from ._base_client import BaseClient
P = ParamSpec("P")
R = TypeVar("R")
_T = TypeVar("_T")
_APIResponseT = TypeVar("_APIResponseT", bound="APIResponse[Any]")
_AsyncAPIResponseT = TypeVar("_AsyncAPIResponseT", bound="AsyncAPIResponse[Any]")
log: logging.Logger = logging.getLogger(__name__)
class BaseAPIResponse(Generic[R]):
    """Shared state and parsing logic for the sync/async API response wrappers.

    Wraps the raw `httpx.Response` together with the request options and the
    target type (`_cast_to`) so the body can be lazily parsed via `_parse()`.
    """

    _cast_to: type[R]  # type the response body should be parsed into
    _client: BaseClient[Any, Any]
    _parsed_by_type: dict[type[Any], Any]  # cache of parsed results, keyed by target type
    _is_sse_stream: bool  # whether the response is a server-sent-events stream
    _stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None
    _options: FinalRequestOptions

    http_response: httpx.Response

    retries_taken: int
    """The number of retries made. If no retries happened this will be `0`"""

    def __init__(
        self,
        *,
        raw: httpx.Response,
        cast_to: type[R],
        client: BaseClient[Any, Any],
        stream: bool,
        stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
        options: FinalRequestOptions,
        retries_taken: int = 0,
    ) -> None:
        self._cast_to = cast_to
        self._client = client
        self._parsed_by_type = {}
        self._is_sse_stream = stream
        self._stream_cls = stream_cls
        self._options = options
        self.http_response = raw
        self.retries_taken = retries_taken

    @property
    def headers(self) -> httpx.Headers:
        return self.http_response.headers

    @property
    def http_request(self) -> httpx.Request:
        """Returns the httpx Request instance associated with the current response."""
        return self.http_response.request

    @property
    def status_code(self) -> int:
        return self.http_response.status_code

    @property
    def url(self) -> httpx.URL:
        """Returns the URL for which the request was made."""
        return self.http_response.url

    @property
    def method(self) -> str:
        return self.http_request.method

    @property
    def http_version(self) -> str:
        return self.http_response.http_version

    @property
    def elapsed(self) -> datetime.timedelta:
        """The time taken for the complete request/response cycle to complete."""
        return self.http_response.elapsed

    @property
    def is_closed(self) -> bool:
        """Whether or not the response body has been closed.

        If this is False then there is response data that has not been read yet.
        You must either fully consume the response body or call `.close()`
        before discarding the response to prevent resource leaks.
        """
        return self.http_response.is_closed

    @override
    def __repr__(self) -> str:
        return (
            f"<{self.__class__.__name__} [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_to}>"
        )

    def _parse(self, *, to: type[_T] | None = None) -> R | _T:
        """Convert the raw response into `cast_to` (or the `to` override).

        Handles SSE streams, primitive casts, `httpx.Response` pass-through and
        model construction via the client's `_process_response_data`.
        """
        cast_to = to if to is not None else self._cast_to

        # unwrap `TypeAlias('Name', T)` -> `T`
        if is_type_alias_type(cast_to):
            cast_to = cast_to.__value__  # type: ignore[unreachable]

        # unwrap `Annotated[T, ...]` -> `T`
        if cast_to and is_annotated_type(cast_to):
            cast_to = extract_type_arg(cast_to, 0)

        origin = get_origin(cast_to) or cast_to

        if self._is_sse_stream:
            if to:
                # custom parse target: it must itself be a stream class
                if not is_stream_class_type(to):
                    raise TypeError(f"Expected custom parse type to be a subclass of {Stream} or {AsyncStream}")

                return cast(
                    _T,
                    to(
                        cast_to=extract_stream_chunk_type(
                            to,
                            failure_message="Expected custom stream type to be passed with a type argument, e.g. Stream[ChunkType]",
                        ),
                        response=self.http_response,
                        client=cast(Any, self._client),
                    ),
                )

            if self._stream_cls:
                return cast(
                    R,
                    self._stream_cls(
                        cast_to=extract_stream_chunk_type(self._stream_cls),
                        response=self.http_response,
                        client=cast(Any, self._client),
                    ),
                )

            # fall back to the client-wide default stream class
            stream_cls = cast("type[Stream[Any]] | type[AsyncStream[Any]] | None", self._client._default_stream_cls)
            if stream_cls is None:
                raise MissingStreamClassError()

            return cast(
                R,
                stream_cls(
                    cast_to=cast_to,
                    response=self.http_response,
                    client=cast(Any, self._client),
                ),
            )

        if cast_to is NoneType:
            return cast(R, None)

        response = self.http_response
        # primitive targets are parsed straight from the body text/content
        if cast_to == str:
            return cast(R, response.text)

        if cast_to == bytes:
            return cast(R, response.content)

        if cast_to == int:
            return cast(R, int(response.text))

        if cast_to == float:
            return cast(R, float(response.text))

        if cast_to == bool:
            return cast(R, response.text.lower() == "true")

        if origin == APIResponse:
            raise RuntimeError("Unexpected state - cast_to is `APIResponse`")

        if inspect.isclass(origin) and issubclass(origin, httpx.Response):
            # Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response
            # and pass that class to our request functions. We cannot change the variance to be either
            # covariant or contravariant as that makes our usage of ResponseT illegal. We could construct
            # the response class ourselves but that is something that should be supported directly in httpx
            # as it would be easy to incorrectly construct the Response object due to the multitude of arguments.
            if cast_to != httpx.Response:
                raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`")
            return cast(R, response)

        if (
            inspect.isclass(
                origin  # pyright: ignore[reportUnknownArgumentType]
            )
            and not issubclass(origin, BaseModel)
            and issubclass(origin, pydantic.BaseModel)
        ):
            raise TypeError("Pydantic models must subclass our base model type, e.g. `from tinker import BaseModel`")

        if (
            cast_to is not object
            and not origin is list
            and not origin is dict
            and not origin is Union
            and not issubclass(origin, BaseModel)
        ):
            raise RuntimeError(
                f"Unsupported type, expected {cast_to} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}."
            )

        # split is required to handle cases where additional information is included
        # in the response, e.g. application/json; charset=utf-8
        content_type, *_ = response.headers.get("content-type", "*").split(";")
        if not content_type.endswith("json"):
            # non-JSON content type: still try JSON for model targets since some
            # servers mislabel their responses
            if is_basemodel(cast_to):
                try:
                    data = response.json()
                except Exception as exc:
                    log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc)
                else:
                    return self._client._process_response_data(
                        data=data,
                        cast_to=cast_to,  # type: ignore
                        response=response,
                    )

            if self._client._strict_response_validation:
                raise APIResponseValidationError(
                    response=response,
                    message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.",
                    body=response.text,
                )

            # If the API responds with content that isn't JSON then we just return
            # the (decoded) text without performing any parsing so that you can still
            # handle the response however you need to.
            return response.text  # type: ignore

        data = response.json()

        return self._client._process_response_data(
            data=data,
            cast_to=cast_to,  # type: ignore
            response=response,
        )
class APIResponse(BaseAPIResponse[R]):
    """Synchronous API response wrapper with eager read/parse helpers."""

    @overload
    def parse(self, *, to: type[_T]) -> _T: ...

    @overload
    def parse(self) -> R: ...

    def parse(self, *, to: type[_T] | None = None) -> R | _T:
        """Returns the rich python representation of this response's data.

        For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.

        You can customise the type that the response is parsed into through
        the `to` argument, e.g.

        ```py
        from tinker import BaseModel


        class MyModel(BaseModel):
            foo: str


        obj = response.parse(to=MyModel)
        print(obj.foo)
        ```

        We support parsing:
          - `BaseModel`
          - `dict`
          - `list`
          - `Union`
          - `str`
          - `int`
          - `float`
          - `httpx.Response`
        """
        cache_key = to if to is not None else self._cast_to
        cached = self._parsed_by_type.get(cache_key)
        if cached is not None:
            return cached  # type: ignore[no-any-return]

        if not self._is_sse_stream:
            # fully consume the body before parsing so httpx buffers it
            self.read()

        parsed = self._parse(to=to)
        if is_given(self._options.post_parser):
            parsed = self._options.post_parser(parsed)

        # cache per target type so repeated `parse()` calls are cheap
        self._parsed_by_type[cache_key] = parsed
        return parsed

    def read(self) -> bytes:
        """Read and return the binary response content."""
        try:
            return self.http_response.read()
        except httpx.StreamConsumed as exc:
            # The default error raised by httpx isn't very
            # helpful in our case so we re-raise it with
            # a different error message.
            raise StreamAlreadyConsumed() from exc

    def text(self) -> str:
        """Read and decode the response content into a string."""
        self.read()
        return self.http_response.text

    def json(self) -> object:
        """Read and decode the JSON response content."""
        self.read()
        return self.http_response.json()

    def close(self) -> None:
        """Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        self.http_response.close()

    def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]:
        """
        A byte-iterator over the decoded response content.

        This automatically handles gzip, deflate and brotli encoded responses.
        """
        for chunk in self.http_response.iter_bytes(chunk_size):
            yield chunk

    def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:
        """A str-iterator over the decoded response content
        that handles both gzip, deflate, etc but also detects the content's
        string encoding.
        """
        for chunk in self.http_response.iter_text(chunk_size):
            yield chunk

    def iter_lines(self) -> Iterator[str]:
        """Like `iter_text()` but will only yield chunks for each line"""
        for chunk in self.http_response.iter_lines():
            yield chunk
class AsyncAPIResponse(BaseAPIResponse[R]):
    """Asynchronous API response wrapper with awaitable read/parse helpers."""

    @overload
    async def parse(self, *, to: type[_T]) -> _T: ...

    @overload
    async def parse(self) -> R: ...

    async def parse(self, *, to: type[_T] | None = None) -> R | _T:
        """Returns the rich python representation of this response's data.

        For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.

        You can customise the type that the response is parsed into through
        the `to` argument, e.g.

        ```py
        from tinker import BaseModel


        class MyModel(BaseModel):
            foo: str


        obj = response.parse(to=MyModel)
        print(obj.foo)
        ```

        We support parsing:
          - `BaseModel`
          - `dict`
          - `list`
          - `Union`
          - `str`
          - `int`
          - `float`
          - `httpx.Response`
        """
        cache_key = to if to is not None else self._cast_to
        cached = self._parsed_by_type.get(cache_key)
        if cached is not None:
            return cached  # type: ignore[no-any-return]

        if not self._is_sse_stream:
            # fully consume the body before parsing so httpx buffers it
            await self.read()

        parsed = self._parse(to=to)
        if is_given(self._options.post_parser):
            parsed = self._options.post_parser(parsed)

        # cache per target type so repeated `parse()` calls are cheap
        self._parsed_by_type[cache_key] = parsed
        return parsed

    async def read(self) -> bytes:
        """Read and return the binary response content."""
        try:
            return await self.http_response.aread()
        except httpx.StreamConsumed as exc:
            # the default error raised by httpx isn't very
            # helpful in our case so we re-raise it with
            # a different error message
            raise StreamAlreadyConsumed() from exc

    async def text(self) -> str:
        """Read and decode the response content into a string."""
        await self.read()
        return self.http_response.text

    async def json(self) -> object:
        """Read and decode the JSON response content."""
        await self.read()
        return self.http_response.json()

    async def close(self) -> None:
        """Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        await self.http_response.aclose()

    async def iter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
        """
        A byte-iterator over the decoded response content.

        This automatically handles gzip, deflate and brotli encoded responses.
        """
        async for chunk in self.http_response.aiter_bytes(chunk_size):
            yield chunk

    async def iter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]:
        """A str-iterator over the decoded response content
        that handles both gzip, deflate, etc but also detects the content's
        string encoding.
        """
        async for chunk in self.http_response.aiter_text(chunk_size):
            yield chunk

    async def iter_lines(self) -> AsyncIterator[str]:
        """Like `iter_text()` but will only yield chunks for each line"""
        async for chunk in self.http_response.aiter_lines():
            yield chunk
class BinaryAPIResponse(APIResponse[bytes]):
    """`APIResponse` specialisation with helpers for binary payloads.

    Note: If you want to stream the response data instead of eagerly reading it
    all at once then you should use `.with_streaming_response` when making
    the API request, e.g. `.with_streaming_response.get_binary_response()`
    """

    def write_to_file(
        self,
        file: str | os.PathLike[str],
    ) -> None:
        """Write the output to the given file.

        Accepts a filename or any path-like object, e.g. pathlib.Path

        Note: if you want to stream the data to the file instead of writing
        all at once then you should use `.with_streaming_response` when making
        the API request, e.g. `.with_streaming_response.get_binary_response()`
        """
        with open(file, mode="wb") as out:
            for chunk in self.iter_bytes():
                out.write(chunk)
class AsyncBinaryAPIResponse(AsyncAPIResponse[bytes]):
    """`AsyncAPIResponse` specialisation with helpers for binary payloads.

    Note: If you want to stream the response data instead of eagerly reading it
    all at once then you should use `.with_streaming_response` when making
    the API request, e.g. `.with_streaming_response.get_binary_response()`
    """

    async def write_to_file(
        self,
        file: str | os.PathLike[str],
    ) -> None:
        """Write the output to the given file.

        Accepts a filename or any path-like object, e.g. pathlib.Path

        Note: if you want to stream the data to the file instead of writing
        all at once then you should use `.with_streaming_response` when making
        the API request, e.g. `.with_streaming_response.get_binary_response()`
        """
        destination = anyio.Path(file)
        async with await destination.open(mode="wb") as out:
            async for chunk in self.iter_bytes():
                await out.write(chunk)
class StreamedBinaryAPIResponse(APIResponse[bytes]):
    """`APIResponse` specialisation for streaming binary payloads to disk."""

    def stream_to_file(
        self,
        file: str | os.PathLike[str],
        *,
        chunk_size: int | None = None,
    ) -> None:
        """Streams the output to the given file.

        Accepts a filename or any path-like object, e.g. pathlib.Path
        """
        with open(file, mode="wb") as out:
            for chunk in self.iter_bytes(chunk_size):
                out.write(chunk)
class AsyncStreamedBinaryAPIResponse(AsyncAPIResponse[bytes]):
    """`AsyncAPIResponse` specialisation for streaming binary payloads to disk."""

    async def stream_to_file(
        self,
        file: str | os.PathLike[str],
        *,
        chunk_size: int | None = None,
    ) -> None:
        """Streams the output to the given file.

        Accepts a filename or any path-like object, e.g. pathlib.Path
        """
        destination = anyio.Path(file)
        async with await destination.open(mode="wb") as out:
            async for chunk in self.iter_bytes(chunk_size):
                await out.write(chunk)
class MissingStreamClassError(TypeError):
    """Raised when `stream=True` is requested without a corresponding `stream_cls`."""

    def __init__(self) -> None:
        message = "The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `tinker._streaming` for reference"
        super().__init__(message)
class StreamAlreadyConsumed(TinkerError):
    """
    Attempted to read or stream content, but the content has already
    been streamed.

    This can happen if you use a method like `.iter_lines()` and then attempt
    to read the entire response body afterwards, e.g.

    ```py
    response = await client.post(...)
    async for line in response.iter_lines():
        ...  # do something with `line`

    content = await response.read()
    # ^ error
    ```

    If you want this behaviour you'll need to either manually accumulate the response
    content or call `await response.read()` before iterating over the stream.
    """

    def __init__(self) -> None:
        # Wording is part of the public surface; keep the exact message text.
        message = (
            "Attempted to read or stream some content, but the content has already been streamed. "
            "This could be due to attempting to stream the response content more than once.\n\n"
            "You can fix this by manually accumulating the response content while streaming "
            "or by calling `.read()` before starting to stream."
        )
        super().__init__(message)
class ResponseContextManager(Generic[_APIResponseT]):
    """Context manager that defers making the request until it is entered and
    guarantees the response is closed when the context manager exits.
    """

    def __init__(self, request_func: Callable[[], _APIResponseT]) -> None:
        self._request_func = request_func
        self.__response: _APIResponseT | None = None

    def __enter__(self) -> _APIResponseT:
        response = self._request_func()
        self.__response = response
        return response

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        response = self.__response
        if response is not None:
            response.close()
class AsyncResponseContextManager(Generic[_AsyncAPIResponseT]):
    """Async context manager that defers awaiting the request until entered and
    guarantees the response is closed when the context manager exits.
    """

    def __init__(self, api_request: Awaitable[_AsyncAPIResponseT]) -> None:
        self._api_request = api_request
        self.__response: _AsyncAPIResponseT | None = None

    async def __aenter__(self) -> _AsyncAPIResponseT:
        response = await self._api_request
        self.__response = response
        return response

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        response = self.__response
        if response is not None:
            await response.close()
def to_streamed_response_wrapper(func: Callable[P, R]) -> Callable[P, ResponseContextManager[APIResponse[R]]]:
    """Wrap a bound API method so that calling it returns a context manager
    yielding a streamable raw `APIResponse` object.
    """

    @functools.wraps(func)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[APIResponse[R]]:
        headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
        headers[RAW_RESPONSE_HEADER] = "stream"
        kwargs["extra_headers"] = headers

        request = functools.partial(func, *args, **kwargs)
        return ResponseContextManager(cast(Callable[[], APIResponse[R]], request))

    return wrapped
def async_to_streamed_response_wrapper(
    func: Callable[P, Awaitable[R]],
) -> Callable[P, AsyncResponseContextManager[AsyncAPIResponse[R]]]:
    """Wrap a bound async API method so that calling it returns an async
    context manager yielding a streamable raw `AsyncAPIResponse` object.
    """

    @functools.wraps(func)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[AsyncAPIResponse[R]]:
        headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
        headers[RAW_RESPONSE_HEADER] = "stream"
        kwargs["extra_headers"] = headers

        request = func(*args, **kwargs)
        return AsyncResponseContextManager(cast(Awaitable[AsyncAPIResponse[R]], request))

    return wrapped
def to_custom_streamed_response_wrapper(
    func: Callable[P, object],
    response_cls: type[_APIResponseT],
) -> Callable[P, ResponseContextManager[_APIResponseT]]:
    """Wrap a bound API method so it streams its response as the given
    response class.

    Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`
    """

    @functools.wraps(func)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[_APIResponseT]:
        headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
        headers[RAW_RESPONSE_HEADER] = "stream"
        headers[OVERRIDE_CAST_TO_HEADER] = response_cls
        kwargs["extra_headers"] = headers

        request = functools.partial(func, *args, **kwargs)
        return ResponseContextManager(cast(Callable[[], _APIResponseT], request))

    return wrapped
def async_to_custom_streamed_response_wrapper(
    func: Callable[P, Awaitable[object]],
    response_cls: type[_AsyncAPIResponseT],
) -> Callable[P, AsyncResponseContextManager[_AsyncAPIResponseT]]:
    """Wrap a bound async API method so it streams its response as the given
    response class.

    Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`
    """

    @functools.wraps(func)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[_AsyncAPIResponseT]:
        headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
        headers[RAW_RESPONSE_HEADER] = "stream"
        headers[OVERRIDE_CAST_TO_HEADER] = response_cls
        kwargs["extra_headers"] = headers

        request = func(*args, **kwargs)
        return AsyncResponseContextManager(cast(Awaitable[_AsyncAPIResponseT], request))

    return wrapped
def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]]:
    """Wrap a bound API method so it returns the raw `APIResponse` object
    instead of the parsed body.
    """

    @functools.wraps(func)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> APIResponse[R]:
        headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
        headers[RAW_RESPONSE_HEADER] = "raw"
        kwargs["extra_headers"] = headers

        return cast(APIResponse[R], func(*args, **kwargs))

    return wrapped
def async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[AsyncAPIResponse[R]]]:
    """Wrap a bound async API method so it returns the raw `AsyncAPIResponse`
    object instead of the parsed body.
    """

    @functools.wraps(func)
    async def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncAPIResponse[R]:
        headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
        headers[RAW_RESPONSE_HEADER] = "raw"
        kwargs["extra_headers"] = headers

        return cast(AsyncAPIResponse[R], await func(*args, **kwargs))

    return wrapped
def to_custom_raw_response_wrapper(
    func: Callable[P, object],
    response_cls: type[_APIResponseT],
) -> Callable[P, _APIResponseT]:
    """Wrap a bound API method so it returns its raw response as the given
    response class.

    Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`
    """

    @functools.wraps(func)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> _APIResponseT:
        headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
        headers[RAW_RESPONSE_HEADER] = "raw"
        headers[OVERRIDE_CAST_TO_HEADER] = response_cls
        kwargs["extra_headers"] = headers

        return cast(_APIResponseT, func(*args, **kwargs))

    return wrapped
def async_to_custom_raw_response_wrapper(
    func: Callable[P, Awaitable[object]],
    response_cls: type[_AsyncAPIResponseT],
) -> Callable[P, Awaitable[_AsyncAPIResponseT]]:
    """Wrap a bound async API method so it returns its raw response as the
    given response class.

    Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`
    """

    @functools.wraps(func)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> Awaitable[_AsyncAPIResponseT]:
        headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
        headers[RAW_RESPONSE_HEADER] = "raw"
        headers[OVERRIDE_CAST_TO_HEADER] = response_cls
        kwargs["extra_headers"] = headers

        return cast(Awaitable[_AsyncAPIResponseT], func(*args, **kwargs))

    return wrapped
def extract_response_type(typ: type[BaseAPIResponse[Any]]) -> type:
    """Given a type like `APIResponse[T]`, returns the generic type variable `T`.

    This also handles the case where a concrete subclass is given, e.g.
    ```py
    class MyResponse(APIResponse[bytes]):
        ...

    extract_response_type(MyResponse) -> bytes
    ```
    """
    bases = cast("tuple[type, ...]", (BaseAPIResponse, APIResponse, AsyncAPIResponse))
    return extract_type_var_from_base(typ, generic_bases=bases, index=0)

333
src/tinker/_streaming.py Normal file
View file

@ -0,0 +1,333 @@
# Note: initially copied from https://github.com/florimondmanca/httpx-sse/blob/master/src/httpx_sse/_decoders.py
from __future__ import annotations
import json
import inspect
from types import TracebackType
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable
import httpx
from ._utils import extract_type_var_from_base
if TYPE_CHECKING:
from ._client import Tinker, AsyncTinker
_T = TypeVar("_T")


class Stream(Generic[_T]):
    """Provides the core interface to iterate over a synchronous stream response."""

    response: httpx.Response
    _decoder: SSEBytesDecoder

    def __init__(
        self,
        *,
        cast_to: type[_T],
        response: httpx.Response,
        client: Tinker,
    ) -> None:
        self.response = response
        self._cast_to = cast_to
        self._client = client
        # The client decides how raw bytes are decoded into SSE events.
        self._decoder = client._make_sse_decoder()
        # Created eagerly so `__next__` and `__iter__` share iteration state.
        self._iterator = self.__stream__()

    def __next__(self) -> _T:
        return self._iterator.__next__()

    def __iter__(self) -> Iterator[_T]:
        for item in self._iterator:
            yield item

    def _iter_events(self) -> Iterator[ServerSentEvent]:
        # Decode the raw HTTP byte stream into individual server-sent events.
        yield from self._decoder.iter_bytes(self.response.iter_bytes())

    def __stream__(self) -> Iterator[_T]:
        """Yield each decoded SSE event, cast to the requested type."""
        cast_to = cast(Any, self._cast_to)
        response = self.response
        process_data = self._client._process_response_data
        iterator = self._iter_events()

        for sse in iterator:
            yield process_data(data=sse.json(), cast_to=cast_to, response=response)

        # Ensure the entire stream is consumed
        for _sse in iterator:
            ...

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.close()

    def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        self.response.close()
class AsyncStream(Generic[_T]):
    """Provides the core interface to iterate over an asynchronous stream response."""

    response: httpx.Response
    _decoder: SSEDecoder | SSEBytesDecoder

    def __init__(
        self,
        *,
        cast_to: type[_T],
        response: httpx.Response,
        client: AsyncTinker,
    ) -> None:
        self.response = response
        self._cast_to = cast_to
        self._client = client
        # The client decides how raw bytes are decoded into SSE events.
        self._decoder = client._make_sse_decoder()
        # Created eagerly so `__anext__` and `__aiter__` share iteration state.
        self._iterator = self.__stream__()

    async def __anext__(self) -> _T:
        return await self._iterator.__anext__()

    async def __aiter__(self) -> AsyncIterator[_T]:
        async for item in self._iterator:
            yield item

    async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
        # Decode the raw HTTP byte stream into individual server-sent events.
        async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
            yield sse

    async def __stream__(self) -> AsyncIterator[_T]:
        """Yield each decoded SSE event, cast to the requested type."""
        cast_to = cast(Any, self._cast_to)
        response = self.response
        process_data = self._client._process_response_data
        iterator = self._iter_events()

        async for sse in iterator:
            yield process_data(data=sse.json(), cast_to=cast_to, response=response)

        # Ensure the entire stream is consumed
        async for _sse in iterator:
            ...

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.close()

    async def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        await self.response.aclose()
class ServerSentEvent:
    """Read-only view of a single decoded server-sent event."""

    def __init__(
        self,
        *,
        event: str | None = None,
        data: str | None = None,
        id: str | None = None,
        retry: int | None = None,
    ) -> None:
        self._id = id
        self._data = data if data is not None else ""
        # Treat an empty event name the same as no event name.
        self._event = event or None
        self._retry = retry

    @property
    def event(self) -> str | None:
        return self._event

    @property
    def id(self) -> str | None:
        return self._id

    @property
    def retry(self) -> int | None:
        return self._retry

    @property
    def data(self) -> str:
        return self._data

    def json(self) -> Any:
        """Parse this event's data payload as JSON."""
        return json.loads(self.data)

    @override
    def __repr__(self) -> str:
        return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})"
class SSEDecoder:
    """Stateful decoder that assembles `ServerSentEvent`s from raw bytes.

    Field values accumulate across `decode()` calls until a blank line marks
    the end of the in-progress event, per the SSE event-stream format.
    """

    _data: list[str]  # accumulated `data:` lines for the in-progress event
    _event: str | None  # pending `event:` field, if any
    _retry: int | None  # pending `retry:` field, if any
    _last_event_id: str | None  # sticky: survives across events per the spec

    def __init__(self) -> None:
        self._event = None
        self._data = []
        self._last_event_id = None
        self._retry = None

    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
        """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
        for chunk in self._iter_chunks(iterator):
            # Split before decoding so splitlines() only uses \r and \n
            for raw_line in chunk.splitlines():
                line = raw_line.decode("utf-8")
                sse = self.decode(line)
                if sse:
                    yield sse

    def _iter_chunks(self, iterator: Iterator[bytes]) -> Iterator[bytes]:
        """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
        data = b""
        for chunk in iterator:
            for line in chunk.splitlines(keepends=True):
                data += line
                # A doubled line terminator marks the end of one SSE chunk.
                if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
                    yield data
                    data = b""
        if data:
            # Flush any trailing partial chunk once the stream ends.
            yield data

    async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
        """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
        async for chunk in self._aiter_chunks(iterator):
            # Split before decoding so splitlines() only uses \r and \n
            for raw_line in chunk.splitlines():
                line = raw_line.decode("utf-8")
                sse = self.decode(line)
                if sse:
                    yield sse

    async def _aiter_chunks(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
        """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
        data = b""
        async for chunk in iterator:
            for line in chunk.splitlines(keepends=True):
                data += line
                # A doubled line terminator marks the end of one SSE chunk.
                if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
                    yield data
                    data = b""
        if data:
            # Flush any trailing partial chunk once the stream ends.
            yield data

    def decode(self, line: str) -> ServerSentEvent | None:
        """Feed one decoded line into the state machine.

        Returns a completed `ServerSentEvent` when `line` is blank (the event
        terminator) and there is pending state; otherwise returns `None`.
        """
        # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation  # noqa: E501

        if not line:
            # Blank line: dispatch the accumulated event, if there is one.
            if not self._event and not self._data and not self._last_event_id and self._retry is None:
                return None

            sse = ServerSentEvent(
                event=self._event,
                data="\n".join(self._data),
                id=self._last_event_id,
                retry=self._retry,
            )

            # NOTE: as per the SSE spec, do not reset last_event_id.
            self._event = None
            self._data = []
            self._retry = None

            return sse

        if line.startswith(":"):
            # Comment line per the spec; ignored.
            return None

        fieldname, _, value = line.partition(":")

        if value.startswith(" "):
            # Strip the single optional space after the colon, per the spec.
            value = value[1:]

        if fieldname == "event":
            self._event = value
        elif fieldname == "data":
            self._data.append(value)
        elif fieldname == "id":
            if "\0" in value:
                # IDs containing NUL are ignored per the spec.
                pass
            else:
                self._last_event_id = value
        elif fieldname == "retry":
            try:
                self._retry = int(value)
            except (TypeError, ValueError):
                # Non-numeric retry values are ignored per the spec.
                pass
        else:
            pass  # Field is ignored.

        return None
@runtime_checkable
class SSEBytesDecoder(Protocol):
    """Structural interface for decoders that turn raw byte streams into SSE events."""

    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
        """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
        ...

    def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
        """Given an async iterator that yields raw binary data, iterate over it & yield every event encountered"""
        ...
def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]:
    """TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`"""
    base = get_origin(typ) or typ
    if not inspect.isclass(base):
        return False
    return issubclass(base, (Stream, AsyncStream))
def extract_stream_chunk_type(
    stream_cls: type,
    *,
    failure_message: str | None = None,
) -> type:
    """Given a type like `Stream[T]`, returns the generic type variable `T`.

    This also handles the case where a concrete subclass is given, e.g.
    ```py
    class MyStream(Stream[bytes]):
        ...

    extract_stream_chunk_type(MyStream) -> bytes
    ```
    """
    from ._base_client import Stream, AsyncStream

    generic_bases = cast("tuple[type, ...]", (Stream, AsyncStream))
    return extract_type_var_from_base(
        stream_cls,
        index=0,
        generic_bases=generic_bases,
        failure_message=failure_message,
    )

219
src/tinker/_types.py Normal file
View file

@ -0,0 +1,219 @@
from __future__ import annotations
from os import PathLike
from typing import (
IO,
TYPE_CHECKING,
Any,
Dict,
List,
Type,
Tuple,
Union,
Mapping,
TypeVar,
Callable,
Optional,
Sequence,
)
from typing_extensions import Set, Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable
import httpx
import pydantic
from httpx import URL, Proxy, Timeout, Response, BaseTransport, AsyncBaseTransport
if TYPE_CHECKING:
from ._models import BaseModel
from ._response import APIResponse, AsyncAPIResponse
# Convenience aliases over the underlying httpx transport types.
Transport = BaseTransport
AsyncTransport = AsyncBaseTransport

# Shapes accepted for query parameters and request bodies.
Query = Mapping[str, object]
Body = object
AnyMapping = Mapping[str, object]
ModelT = TypeVar("ModelT", bound=pydantic.BaseModel)
_T = TypeVar("_T")

# Approximates httpx internal ProxiesTypes and RequestFiles types
# while adding support for `PathLike` instances
ProxiesDict = Dict["str | URL", Union[None, str, URL, Proxy]]
ProxiesTypes = Union[str, Proxy, ProxiesDict]
if TYPE_CHECKING:
    Base64FileInput = Union[IO[bytes], PathLike[str]]
    FileContent = Union[IO[bytes], bytes, PathLike[str]]
else:
    Base64FileInput = Union[IO[bytes], PathLike]
    FileContent = Union[IO[bytes], bytes, PathLike]  # PathLike is not subscriptable in Python 3.8.
FileTypes = Union[
    # file (or bytes)
    FileContent,
    # (filename, file (or bytes))
    Tuple[Optional[str], FileContent],
    # (filename, file (or bytes), content_type)
    Tuple[Optional[str], FileContent, Optional[str]],
    # (filename, file (or bytes), content_type, headers)
    Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]],
]
RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]]

# duplicate of the above but without our custom file support
HttpxFileContent = Union[IO[bytes], bytes]
HttpxFileTypes = Union[
    # file (or bytes)
    HttpxFileContent,
    # (filename, file (or bytes))
    Tuple[Optional[str], HttpxFileContent],
    # (filename, file (or bytes), content_type)
    Tuple[Optional[str], HttpxFileContent, Optional[str]],
    # (filename, file (or bytes), content_type, headers)
    Tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]],
]
HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]]

# Workaround to support (cast_to: Type[ResponseT]) -> ResponseT
# where ResponseT includes `None`. In order to support directly
# passing `None`, overloads would have to be defined for every
# method that uses `ResponseT` which would lead to an unacceptable
# amount of code duplication and make it unreadable. See _base_client.py
# for example usage.
#
# This unfortunately means that you will either have
# to import this type and pass it explicitly:
#
# from tinker import NoneType
# client.get('/foo', cast_to=NoneType)
#
# or build it yourself:
#
# client.get('/foo', cast_to=type(None))
if TYPE_CHECKING:
    NoneType: Type[None]
else:
    NoneType = type(None)
class RequestOptions(TypedDict, total=False):
    """Per-request overrides that client methods accept alongside the payload."""

    headers: Headers
    max_retries: int
    timeout: float | Timeout | None
    params: Query
    extra_json: AnyMapping
    idempotency_key: str
    follow_redirects: bool
# Sentinel class used until PEP 0661 is accepted
class NotGiven:
    """Singleton sentinel distinguishing an omitted keyword argument from one
    explicitly passed as `None` (which may have different behavior).

    For example:

    ```py
    def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ...


    get(timeout=1)  # 1s timeout
    get(timeout=None)  # No timeout
    get()  # Default timeout behavior, which may not be statically known at the method definition.
    ```
    """

    def __bool__(self) -> Literal[False]:
        # Always falsy so `if not given:` style checks work.
        return False

    @override
    def __repr__(self) -> str:
        return "NOT_GIVEN"


NotGivenOr = Union[_T, NotGiven]
NOT_GIVEN = NotGiven()
class Omit:
    """Sentinel for removing a default value outright where `None` would be
    ambiguous or invalid, for example:

    ```py
    # as the default `Content-Type` header is `application/json` that will be sent
    client.post("/upload/files", files={"file": b"my raw file content"})

    # you can't explicitly override the header as it has to be dynamically generated
    # to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983'
    client.post(..., headers={"Content-Type": "multipart/form-data"})

    # instead you can remove the default `application/json` header by passing Omit
    client.post(..., headers={"Content-Type": Omit()})
    ```
    """

    def __bool__(self) -> Literal[False]:
        # Always falsy, mirroring `NotGiven`.
        return False
@runtime_checkable
class ModelBuilderProtocol(Protocol):
    """Structural type for models that construct themselves from a raw response."""

    @classmethod
    def build(
        cls: type[_T],
        *,
        response: Response,
        data: object,
    ) -> _T: ...
# A header value of `Omit` removes the default header entirely.
Headers = Mapping[str, Union[str, Omit]]


class HeadersLikeProtocol(Protocol):
    """Anything exposing a `.get(key)` method returning an optional string."""

    def get(self, __key: str) -> str | None: ...


HeadersLike = Union[Headers, HeadersLikeProtocol]
# TypeVar covering every shape a response can be deserialized into.
ResponseT = TypeVar(
    "ResponseT",
    bound=Union[
        object,
        str,
        None,
        "BaseModel",
        List[Any],
        Dict[str, Any],
        Response,
        ModelBuilderProtocol,
        "APIResponse[Any]",
        "AsyncAPIResponse[Any]",
    ],
)

StrBytesIntFloat = Union[str, bytes, int, float]

# Note: copied from Pydantic
# https://github.com/pydantic/pydantic/blob/6f31f8f68ef011f84357330186f603ff295312fd/pydantic/main.py#L79
IncEx: TypeAlias = Union[Set[int], Set[str], Mapping[int, Union["IncEx", bool]], Mapping[str, Union["IncEx", bool]]]

# Callable applied to a parsed response before it is returned to the caller.
PostParser = Callable[[Any], Any]
@runtime_checkable
class InheritsGeneric(Protocol):
    """Represents a type that has inherited from `Generic`

    The `__orig_bases__` property can be used to determine the resolved
    type variable for a given base class.
    """

    __orig_bases__: tuple[_GenericAlias]
class _GenericAlias(Protocol):
    """Structural stand-in for `typing`'s internal generic-alias objects."""

    __origin__: type[object]
class HttpxSendArgs(TypedDict, total=False):
    """Extra keyword arguments forwarded to httpx's `send`."""

    auth: httpx.Auth
    follow_redirects: bool

View file

@ -0,0 +1,57 @@
from ._sync import asyncify as asyncify
from ._proxy import LazyProxy as LazyProxy
from ._utils import (
flatten as flatten,
is_dict as is_dict,
is_list as is_list,
is_given as is_given,
is_tuple as is_tuple,
json_safe as json_safe,
lru_cache as lru_cache,
is_mapping as is_mapping,
is_tuple_t as is_tuple_t,
parse_date as parse_date,
is_iterable as is_iterable,
is_sequence as is_sequence,
coerce_float as coerce_float,
is_mapping_t as is_mapping_t,
removeprefix as removeprefix,
removesuffix as removesuffix,
extract_files as extract_files,
is_sequence_t as is_sequence_t,
required_args as required_args,
coerce_boolean as coerce_boolean,
coerce_integer as coerce_integer,
file_from_path as file_from_path,
parse_datetime as parse_datetime,
strip_not_given as strip_not_given,
deepcopy_minimal as deepcopy_minimal,
get_async_library as get_async_library,
maybe_coerce_float as maybe_coerce_float,
get_required_header as get_required_header,
maybe_coerce_boolean as maybe_coerce_boolean,
maybe_coerce_integer as maybe_coerce_integer,
)
from ._typing import (
is_list_type as is_list_type,
is_union_type as is_union_type,
extract_type_arg as extract_type_arg,
is_iterable_type as is_iterable_type,
is_required_type as is_required_type,
is_annotated_type as is_annotated_type,
is_type_alias_type as is_type_alias_type,
strip_annotated_type as strip_annotated_type,
extract_type_var_from_base as extract_type_var_from_base,
)
from ._streams import consume_sync_iterator as consume_sync_iterator, consume_async_iterator as consume_async_iterator
from ._transform import (
PropertyInfo as PropertyInfo,
transform as transform,
async_transform as async_transform,
maybe_transform as maybe_transform,
async_maybe_transform as async_maybe_transform,
)
from ._reflection import (
function_has_argument as function_has_argument,
assert_signatures_in_sync as assert_signatures_in_sync,
)

View file

@ -0,0 +1,25 @@
import os
import logging
logger: logging.Logger = logging.getLogger("tinker")
httpx_logger: logging.Logger = logging.getLogger("httpx")


def _basic_config() -> None:
    """Install the SDK's default log format via `logging.basicConfig`."""
    # e.g. [2023-10-05 14:12:26 - tinker._base_client:818 - DEBUG] HTTP Request: POST http://127.0.0.1:4010/foo/bar "200 OK"
    logging.basicConfig(
        format="[%(asctime)s - %(name)s:%(lineno)d - %(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )


def setup_logging() -> None:
    """Configure SDK logging from the `TINKER_LOG` environment variable.

    Recognized values are "debug" and "info"; anything else leaves logging
    untouched.
    """
    level_by_env = {"debug": logging.DEBUG, "info": logging.INFO}
    level = level_by_env.get(os.environ.get("TINKER_LOG", ""))
    if level is None:
        return
    _basic_config()
    logger.setLevel(level)
    httpx_logger.setLevel(level)

View file

@ -0,0 +1,65 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Generic, TypeVar, Iterable, cast
from typing_extensions import override
T = TypeVar("T")


class LazyProxy(Generic[T], ABC):
    """Implements data methods to pretend that an instance is another instance.

    This includes forwarding attribute access and other methods.
    """

    # Note: we have to special case proxies that themselves return proxies
    # to support using a proxy as a catch-all for any random access, e.g. `proxy.foo.bar.baz`

    def __getattr__(self, attr: str) -> object:
        proxied = self.__get_proxied__()
        if isinstance(proxied, LazyProxy):
            return proxied  # pyright: ignore
        return getattr(proxied, attr)

    @override
    def __repr__(self) -> str:
        proxied = self.__get_proxied__()
        if isinstance(proxied, LazyProxy):
            return proxied.__class__.__name__
        return repr(self.__get_proxied__())

    @override
    def __str__(self) -> str:
        proxied = self.__get_proxied__()
        if isinstance(proxied, LazyProxy):
            return proxied.__class__.__name__
        return str(proxied)

    @override
    def __dir__(self) -> Iterable[str]:
        proxied = self.__get_proxied__()
        if isinstance(proxied, LazyProxy):
            return []
        return proxied.__dir__()

    @property  # type: ignore
    @override
    def __class__(self) -> type:  # pyright: ignore
        # Report the proxied object's class so the proxy masquerades as it;
        # fall back to the real class if loading the target fails.
        try:
            proxied = self.__get_proxied__()
        except Exception:
            return type(self)
        if issubclass(type(proxied), LazyProxy):
            return type(proxied)
        return proxied.__class__

    def __get_proxied__(self) -> T:
        # Loads on every access; subclasses may cache inside `__load__`.
        return self.__load__()

    def __as_proxied__(self) -> T:
        """Helper method that returns the current proxy, typed as the loaded object"""
        return cast(T, self)

    @abstractmethod
    def __load__(self) -> T: ...

View file

@ -0,0 +1,42 @@
from __future__ import annotations
import inspect
from typing import Any, Callable
def function_has_argument(func: Callable[..., Any], arg_name: str) -> bool:
    """Returns whether or not the given function has a specific parameter"""
    return arg_name in inspect.signature(func).parameters
def assert_signatures_in_sync(
    source_func: Callable[..., Any],
    check_func: Callable[..., Any],
    *,
    exclude_params: set[str] = set(),  # noqa: B006 - never mutated, so a shared default is safe
) -> None:
    """Ensure that the signature of the second function matches the first.

    Args:
        source_func: function whose signature is the source of truth.
        check_func: function being checked against ``source_func``.
        exclude_params: parameter names to skip when comparing.

    Raises:
        AssertionError: if any parameter is missing or annotated differently.
    """
    check_sig = inspect.signature(check_func)
    source_sig = inspect.signature(source_func)

    errors: list[str] = []

    for name, source_param in source_sig.parameters.items():
        if name in exclude_params:
            continue

        custom_param = check_sig.parameters.get(name)
        if not custom_param:
            errors.append(f"the `{name}` param is missing")
            continue

        if custom_param.annotation != source_param.annotation:
            # Fixed wording: previously read "are do not match".
            errors.append(
                f"types for the `{name}` param do not match; source={repr(source_param.annotation)} checking={repr(custom_param.annotation)}"
            )
            continue

    if errors:
        raise AssertionError(f"{len(errors)} errors encountered when comparing signatures:\n\n" + "\n\n".join(errors))

View file

@ -0,0 +1,24 @@
from __future__ import annotations
from typing import Any
from typing_extensions import override
from ._proxy import LazyProxy
class ResourcesProxy(LazyProxy[Any]):
    """A proxy for the `tinker.resources` module.

    This is used so that we can lazily import `tinker.resources` only when
    needed *and* so that users can just import `tinker` and reference `tinker.resources`
    """

    @override
    def __load__(self) -> Any:
        import importlib

        return importlib.import_module("tinker.resources")


resources = ResourcesProxy().__as_proxied__()

View file

@ -0,0 +1,12 @@
from typing import Any
from typing_extensions import Iterator, AsyncIterator
def consume_sync_iterator(iterator: Iterator[Any]) -> None:
    """Exhaust the given iterator, discarding every item."""
    for _ in iterator:
        pass
async def consume_async_iterator(iterator: AsyncIterator[Any]) -> None:
    """Exhaust the given async iterator, discarding every item."""
    async for _ in iterator:
        pass

View file

@ -0,0 +1,86 @@
from __future__ import annotations
import sys
import asyncio
import functools
import contextvars
from typing import Any, TypeVar, Callable, Awaitable
from typing_extensions import ParamSpec
import anyio
import sniffio
import anyio.to_thread
T_Retval = TypeVar("T_Retval")
T_ParamSpec = ParamSpec("T_ParamSpec")


if sys.version_info >= (3, 9):
    # Native implementation is available from Python 3.9 onwards.
    _asyncio_to_thread = asyncio.to_thread
else:
    # backport of https://docs.python.org/3/library/asyncio-task.html#asyncio.to_thread
    # for Python 3.8 support
    async def _asyncio_to_thread(
        func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs
    ) -> Any:
        """Asynchronously run function *func* in a separate thread.

        Any *args and **kwargs supplied for this function are directly passed
        to *func*. Also, the current :class:`contextvars.Context` is propagated,
        allowing context variables from the main thread to be accessed in the
        separate thread.

        Returns a coroutine that can be awaited to get the eventual result of *func*.
        """
        loop = asyncio.events.get_running_loop()
        ctx = contextvars.copy_context()
        # Run under a copy of the caller's context so context variables propagate.
        func_call = functools.partial(ctx.run, func, *args, **kwargs)
        return await loop.run_in_executor(None, func_call)
async def to_thread(
    func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs
) -> T_Retval:
    """Run the blocking `func` on a worker thread under the active async library."""
    if sniffio.current_async_library() == "asyncio":
        return await _asyncio_to_thread(func, *args, **kwargs)

    bound = functools.partial(func, *args, **kwargs)
    return await anyio.to_thread.run_sync(bound)
# inspired by `asyncer`, https://github.com/tiangolo/asyncer
def asyncify(function: Callable[T_ParamSpec, T_Retval]) -> Callable[T_ParamSpec, Awaitable[T_Retval]]:
    """Turn a blocking callable into an awaitable one.

    The returned async function accepts the same positional and keyword
    arguments as `function` and runs it on a worker thread — via
    `asyncio.to_thread` on Python 3.9+, or a locally defined backport of it
    on Python 3.8 — returning its result.

    Usage:

    ```python
    def blocking_func(arg1, arg2, kwarg1=None):
        # blocking code
        return result


    result = await asyncify(blocking_func)(arg1, arg2, kwarg1=value1)
    ```

    ## Arguments

    `function`: a blocking regular callable (e.g. a function)

    ## Return

    An async function that takes the same positional and keyword arguments as the
    original one, that when called runs the same original function in a thread worker
    and returns the result.
    """

    async def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval:
        return await to_thread(function, *args, **kwargs)

    return wrapper

View file

@ -0,0 +1,447 @@
from __future__ import annotations
import io
import base64
import pathlib
from typing import Any, Mapping, TypeVar, cast
from datetime import date, datetime
from typing_extensions import Literal, get_args, override, get_type_hints as _get_type_hints
import anyio
import pydantic
from ._utils import (
is_list,
is_given,
lru_cache,
is_mapping,
is_iterable,
)
from .._files import is_base64_file_input
from ._typing import (
is_list_type,
is_union_type,
extract_type_arg,
is_iterable_type,
is_required_type,
is_annotated_type,
strip_annotated_type,
)
from .._compat import get_origin, model_dump, is_typeddict
# Generic value type threaded through `transform()` / `async_transform()`.
_T = TypeVar("_T")

# TODO: support for drilling globals() and locals()
# TODO: ensure works correctly with forward references in all cases

# Serialisation formats understood by `PropertyInfo(format=...)`; applied by
# `_format_data()` / `_async_format_data()` below.
PropertyFormat = Literal["iso8601", "base64", "custom"]
class PropertyInfo:
    """Metadata class to be used in Annotated types to provide information about a given type.

    For example:

    class MyParams(TypedDict):
        account_holder_name: Annotated[str, PropertyInfo(alias='accountHolderName')]

    This means that {'account_holder_name': 'Robert'} will be transformed to {'accountHolderName': 'Robert'} before being sent to the API.
    """

    # wire-level key to send instead of the Python attribute name
    alias: str | None
    # how to serialise the value; see `PropertyFormat`
    format: PropertyFormat | None
    # strftime-style template, only consulted when `format == "custom"`
    format_template: str | None
    # union discriminator field name — not read in this module; presumably
    # consumed by the union-handling code elsewhere in the SDK
    discriminator: str | None

    def __init__(
        self,
        *,
        alias: str | None = None,
        format: PropertyFormat | None = None,
        format_template: str | None = None,
        discriminator: str | None = None,
    ) -> None:
        self.alias = alias
        self.format = format
        self.format_template = format_template
        self.discriminator = discriminator

    @override
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')"
def maybe_transform(
    data: object,
    expected_type: object,
) -> Any | None:
    """Wrapper over `transform()` that allows `None` to be passed.

    See `transform()` for more details.
    """
    return None if data is None else transform(data, expected_type)
# Wrapper over _transform_recursive providing fake types
def transform(
    data: _T,
    expected_type: object,
) -> _T:
    """Transform dictionaries based off of type information from the given type, for example:

    ```py
    class Params(TypedDict, total=False):
        card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]


    transformed = transform({"card_id": "<my card ID>"}, Params)
    # {'cardID': '<my card ID>'}
    ```

    Any keys / data that does not have type information given will be included as is.

    It should be noted that the transformations that this function does are not represented in the type system.
    """
    result = _transform_recursive(data, annotation=cast(type, expected_type))
    return cast(_T, result)
@lru_cache(maxsize=8096)
def _get_annotated_type(type_: type) -> type | None:
    """If the given type is an `Annotated` type then it is returned, if not `None` is returned.

    This also unwraps the type when applicable, e.g. `Required[Annotated[T, ...]]`
    """
    # peel off a `Required[...]` wrapper first so that
    # `Required[Annotated[T, ...]]` is still recognised
    candidate = get_args(type_)[0] if is_required_type(type_) else type_
    return candidate if is_annotated_type(candidate) else None
def _maybe_transform_key(key: str, type_: type) -> str:
    """Transform the given `data` based on the annotations provided in `type_`.

    Note: this function only looks at `Annotated` types that contain `PropertyInfo` metadata.
    """
    annotated_type = _get_annotated_type(type_)
    if annotated_type is None:
        # no `Annotated` definition for this type, no transformation needed
        return key

    # skip the first argument as it is the actual type, not metadata
    for metadata in get_args(annotated_type)[1:]:
        if isinstance(metadata, PropertyInfo) and metadata.alias is not None:
            return metadata.alias

    return key
def _no_transform_needed(annotation: type) -> bool:
    """Whether values of this element type are emitted as-is (no alias/format rewriting)."""
    return annotation in (int, float)
def _transform_recursive(
    data: object,
    *,
    annotation: type,
    inner_type: type | None = None,
) -> object:
    """Transform the given data against the expected type.

    Dispatches, in order: TypedDicts, `Dict[K, V]`, `List[T]`/`Iterable[T]`,
    unions, pydantic models, then scalar `PropertyInfo(format=...)` handling.

    Args:
        annotation: The direct type annotation given to the particular piece of data.
            This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc

        inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
            is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
            the list can be transformed using the metadata from the container type.

            Defaults to the same value as the `annotation` argument.
    """
    if inner_type is None:
        inner_type = annotation

    stripped_type = strip_annotated_type(inner_type)
    origin = get_origin(stripped_type) or stripped_type
    if is_typeddict(stripped_type) and is_mapping(data):
        return _transform_typeddict(data, stripped_type)

    if origin == dict and is_mapping(data):
        items_type = get_args(stripped_type)[1]
        return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}

    if (
        # List[T]
        (is_list_type(stripped_type) and is_list(data))
        # Iterable[T]
        or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
    ):
        # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
        # intended as an iterable, so we don't transform it.
        if isinstance(data, dict):
            return cast(object, data)

        inner_type = extract_type_arg(stripped_type, 0)
        if _no_transform_needed(inner_type):
            # for some types there is no need to transform anything, so we can get a small
            # perf boost from skipping that work.
            #
            # but we still need to convert to a list to ensure the data is json-serializable
            if is_list(data):
                return data

            return list(data)

        # note: `annotation` (not the stripped type) is passed so the outer
        # `Annotated[...]` metadata still applies to each element
        return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]

    if is_union_type(stripped_type):
        # For union types we run the transformation against all subtypes to ensure that everything is transformed.
        #
        # TODO: there may be edge cases where the same normalized field name will transform to two different names
        # in different subtypes.
        for subtype in get_args(stripped_type):
            data = _transform_recursive(data, annotation=annotation, inner_type=subtype)
        return data

    if isinstance(data, pydantic.BaseModel):
        return model_dump(data, exclude_unset=True, mode="json")

    annotated_type = _get_annotated_type(annotation)
    if annotated_type is None:
        return data

    # ignore the first argument as it is the actual type
    annotations = get_args(annotated_type)[1:]
    # the loop variable deliberately shadows the `annotation` parameter — this
    # is its last use in the function, so the shadowing is harmless
    for annotation in annotations:
        if isinstance(annotation, PropertyInfo) and annotation.format is not None:
            return _format_data(data, annotation.format, annotation.format_template)

    return data
def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
    """Apply a `PropertyInfo.format` spec to a single scalar value.

    Dates/datetimes honour "iso8601" and "custom" (strftime via
    *format_template*); "base64" reads file-like inputs and encodes them.
    Anything else is returned unchanged.
    """
    if isinstance(data, (date, datetime)):
        if format_ == "iso8601":
            return data.isoformat()

        if format_ == "custom" and format_template is not None:
            return data.strftime(format_template)

    if format_ == "base64" and is_base64_file_input(data):
        contents: str | bytes | None = None

        if isinstance(data, pathlib.Path):
            contents = data.read_bytes()
        elif isinstance(data, io.IOBase):
            contents = data.read()

            if isinstance(contents, str):  # type: ignore[unreachable]
                contents = contents.encode()

        if not isinstance(contents, bytes):
            raise RuntimeError(f"Could not read bytes from {data}; Received {type(contents)}")

        return base64.b64encode(contents).decode("ascii")

    return data
def _transform_typeddict(
    data: Mapping[str, object],
    expected_type: type,
) -> Mapping[str, object]:
    """Transform a mapping using the annotations declared on the TypedDict `expected_type`."""
    hints = get_type_hints(expected_type, include_extras=True)
    out: dict[str, object] = {}
    for key, value in data.items():
        if not is_given(value):
            # `NotGiven` values are stripped out before the request is
            # sent anyway, so there is no need to include them here
            continue
        field_type = hints.get(key)
        if field_type is None:
            # no type annotation for this field, leave it as is
            out[key] = value
        else:
            out[_maybe_transform_key(key, field_type)] = _transform_recursive(value, annotation=field_type)
    return out
async def async_maybe_transform(
    data: object,
    expected_type: object,
) -> Any | None:
    """Wrapper over `async_transform()` that allows `None` to be passed.

    See `async_transform()` for more details.
    """
    return None if data is None else await async_transform(data, expected_type)
async def async_transform(
    data: _T,
    expected_type: object,
) -> _T:
    """Transform dictionaries based off of type information from the given type, for example:

    ```py
    class Params(TypedDict, total=False):
        card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]


    transformed = transform({"card_id": "<my card ID>"}, Params)
    # {'cardID': '<my card ID>'}
    ```

    Any keys / data that does not have type information given will be included as is.

    It should be noted that the transformations that this function does are not represented in the type system.
    """
    result = await _async_transform_recursive(data, annotation=cast(type, expected_type))
    return cast(_T, result)
async def _async_transform_recursive(
    data: object,
    *,
    annotation: type,
    inner_type: type | None = None,
) -> object:
    """Transform the given data against the expected type.

    Async variant of `_transform_recursive`: formatting runs through
    `_async_format_data` so file reads never block the event loop.

    Args:
        annotation: The direct type annotation given to the particular piece of data.
            This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc

        inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
            is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
            the list can be transformed using the metadata from the container type.

            Defaults to the same value as the `annotation` argument.
    """
    if inner_type is None:
        inner_type = annotation

    stripped_type = strip_annotated_type(inner_type)
    origin = get_origin(stripped_type) or stripped_type
    if is_typeddict(stripped_type) and is_mapping(data):
        return await _async_transform_typeddict(data, stripped_type)

    if origin == dict and is_mapping(data):
        items_type = get_args(stripped_type)[1]
        # fix: this previously delegated to the *sync* `_transform_recursive`,
        # which produced the same output but routed nested values through the
        # blocking `_format_data` path; keep the whole tree on the async path.
        return {
            key: await _async_transform_recursive(value, annotation=items_type) for key, value in data.items()
        }

    if (
        # List[T]
        (is_list_type(stripped_type) and is_list(data))
        # Iterable[T]
        or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
    ):
        # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
        # intended as an iterable, so we don't transform it.
        if isinstance(data, dict):
            return cast(object, data)

        inner_type = extract_type_arg(stripped_type, 0)
        if _no_transform_needed(inner_type):
            # for some types there is no need to transform anything, so we can get a small
            # perf boost from skipping that work.
            #
            # but we still need to convert to a list to ensure the data is json-serializable
            if is_list(data):
                return data

            return list(data)

        return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]

    if is_union_type(stripped_type):
        # For union types we run the transformation against all subtypes to ensure that everything is transformed.
        #
        # TODO: there may be edge cases where the same normalized field name will transform to two different names
        # in different subtypes.
        for subtype in get_args(stripped_type):
            data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
        return data

    if isinstance(data, pydantic.BaseModel):
        return model_dump(data, exclude_unset=True, mode="json")

    annotated_type = _get_annotated_type(annotation)
    if annotated_type is None:
        return data

    # ignore the first argument as it is the actual type
    annotations = get_args(annotated_type)[1:]
    for annotation in annotations:
        if isinstance(annotation, PropertyInfo) and annotation.format is not None:
            return await _async_format_data(data, annotation.format, annotation.format_template)

    return data
async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
    """Async twin of `_format_data`; path reads go through anyio so the loop is not blocked."""
    if isinstance(data, (date, datetime)):
        if format_ == "iso8601":
            return data.isoformat()

        if format_ == "custom" and format_template is not None:
            return data.strftime(format_template)

    if format_ == "base64" and is_base64_file_input(data):
        contents: str | bytes | None = None

        if isinstance(data, pathlib.Path):
            # non-blocking file read
            contents = await anyio.Path(data).read_bytes()
        elif isinstance(data, io.IOBase):
            contents = data.read()

            if isinstance(contents, str):  # type: ignore[unreachable]
                contents = contents.encode()

        if not isinstance(contents, bytes):
            raise RuntimeError(f"Could not read bytes from {data}; Received {type(contents)}")

        return base64.b64encode(contents).decode("ascii")

    return data
async def _async_transform_typeddict(
    data: Mapping[str, object],
    expected_type: type,
) -> Mapping[str, object]:
    """Async counterpart of `_transform_typeddict`."""
    hints = get_type_hints(expected_type, include_extras=True)
    out: dict[str, object] = {}
    for key, value in data.items():
        if not is_given(value):
            # `NotGiven` values are stripped out before the request is
            # sent anyway, so there is no need to include them here
            continue
        field_type = hints.get(key)
        if field_type is None:
            # no type annotation for this field, leave it as is
            out[key] = value
        else:
            out[_maybe_transform_key(key, field_type)] = await _async_transform_recursive(value, annotation=field_type)
    return out
@lru_cache(maxsize=8096)
def get_type_hints(
    obj: Any,
    globalns: dict[str, Any] | None = None,
    localns: Mapping[str, Any] | None = None,
    include_extras: bool = False,
) -> dict[str, Any]:
    """Memoised wrapper around `typing_extensions.get_type_hints`.

    Resolving hints is comparatively expensive and the same TypedDicts are
    inspected repeatedly during request transformation, hence the cache.

    NOTE(review): the cached dict is shared between callers — do not mutate
    the returned mapping.
    """
    return _get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras)

View file

@ -0,0 +1,151 @@
from __future__ import annotations
import sys
import typing
import typing_extensions
from typing import Any, TypeVar, Iterable, cast
from collections import abc as _c_abc
from typing_extensions import (
TypeIs,
Required,
Annotated,
get_args,
get_origin,
)
from ._utils import lru_cache
from .._types import InheritsGeneric
from .._compat import is_union as _is_union
def is_annotated_type(typ: type) -> bool:
    """Whether `typ` is an `Annotated[...]` special form."""
    origin = get_origin(typ)
    return origin == Annotated
def is_list_type(typ: type) -> bool:
    """Whether `typ` is `list` itself or a parameterised `List[T]`."""
    origin = get_origin(typ)
    return (origin if origin is not None else typ) == list
def is_iterable_type(typ: type) -> bool:
    """If the given type is `typing.Iterable[T]`"""
    origin = get_origin(typ) or typ
    # a parameterised `typing.Iterable[T]` reports `collections.abc.Iterable`
    # as its origin, so compare against both spellings
    return origin in (Iterable, _c_abc.Iterable)
def is_union_type(typ: type) -> bool:
    """Whether `typ` is a union, e.g. `Union[int, str]` or `int | str`."""
    origin = get_origin(typ)
    return _is_union(origin)
def is_required_type(typ: type) -> bool:
return get_origin(typ) == Required
def is_typevar(typ: type) -> bool:
    """Whether `typ` is a `TypeVar` instance (an exact-class check, not `isinstance`)."""
    # type ignore is required because type checkers
    # think this expression will always return False
    return type(typ) == TypeVar  # type: ignore
# Classes that represent `TypeAliasType`: always the typing_extensions one, and
# on Python 3.12+ also the stdlib class created by PEP 695 `type X = ...` aliases.
_TYPE_ALIAS_TYPES: tuple[type[typing_extensions.TypeAliasType], ...] = (typing_extensions.TypeAliasType,)
if sys.version_info >= (3, 12):
    _TYPE_ALIAS_TYPES = (*_TYPE_ALIAS_TYPES, typing.TypeAliasType)
def is_type_alias_type(tp: Any, /) -> TypeIs[typing_extensions.TypeAliasType]:
    """Return whether the provided argument is an instance of `TypeAliasType`.

    ```python
    type Int = int
    is_type_alias_type(Int)
    # > True
    Str = TypeAliasType("Str", str)
    is_type_alias_type(Str)
    # > True
    ```
    """
    # checks against both the typing_extensions class and, on 3.12+, the stdlib one
    return isinstance(tp, _TYPE_ALIAS_TYPES)
# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
@lru_cache(maxsize=8096)
def strip_annotated_type(typ: type) -> type:
    """Iteratively unwrap `Required[...]` / `Annotated[...]` layers down to the bare type."""
    while is_required_type(typ) or is_annotated_type(typ):
        typ = cast(type, get_args(typ)[0])
    return typ
def extract_type_arg(typ: type, index: int) -> type:
    """Return the type argument of `typ` at `index`, e.g. `extract_type_arg(List[int], 0) == int`.

    Raises `RuntimeError` when `typ` has no argument at that position.
    """
    type_args = get_args(typ)
    try:
        return cast(type, type_args[index])
    except IndexError as err:
        raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err
def extract_type_var_from_base(
    typ: type,
    *,
    generic_bases: tuple[type, ...],
    index: int,
    failure_message: str | None = None,
) -> type:
    """Given a type like `Foo[T]`, returns the generic type variable `T`.

    This also handles the case where a concrete subclass is given, e.g.
    ```py
    class MyResponse(Foo[bytes]):
        ...

    extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes
    ```

    And where a generic subclass is given:
    ```py
    _T = TypeVar('_T')
    class MyResponse(Foo[_T]):
        ...

    extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes
    ```

    Raises `RuntimeError` (with `failure_message` when provided) if the type
    variable cannot be resolved.
    """
    cls = cast(object, get_origin(typ) or typ)
    if cls in generic_bases:  # pyright: ignore[reportUnnecessaryContains]
        # we're given the class directly
        return extract_type_arg(typ, index)

    # if a subclass is given
    # ---
    # this is needed as __orig_bases__ is not present in the typeshed stubs
    # because it is intended to be for internal use only, however there does
    # not seem to be a way to resolve generic TypeVars for inherited subclasses
    # without using it.
    if isinstance(cls, InheritsGeneric):
        target_base_class: Any | None = None
        # find the first base that was parameterised from one of `generic_bases`
        for base in cls.__orig_bases__:
            if base.__origin__ in generic_bases:
                target_base_class = base
                break

        if target_base_class is None:
            raise RuntimeError(
                "Could not find the generic base class;\n"
                "This should never happen;\n"
                f"Does {cls} inherit from one of {generic_bases} ?"
            )

        extracted = extract_type_arg(target_base_class, index)
        if is_typevar(extracted):
            # If the extracted type argument is itself a type variable
            # then that means the subclass itself is generic, so we have
            # to resolve the type argument from the class itself, not
            # the base class.
            #
            # Note: if there is more than 1 type argument, the subclass could
            # change the ordering of the type arguments, this is not currently
            # supported.
            return extract_type_arg(typ, index)

        return extracted

    raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}")

422
src/tinker/_utils/_utils.py Normal file
View file

@ -0,0 +1,422 @@
from __future__ import annotations
import os
import re
import inspect
import functools
from typing import (
Any,
Tuple,
Mapping,
TypeVar,
Callable,
Iterable,
Sequence,
cast,
overload,
)
from pathlib import Path
from datetime import date, datetime
from typing_extensions import TypeGuard
import sniffio
from .._types import NotGiven, FileTypes, NotGivenOr, HeadersLike
from .._compat import parse_date as parse_date, parse_datetime as parse_datetime
# Generic type variables backing the typed `is_*_t` narrowing helpers and
# decorators defined below.
_T = TypeVar("_T")
_TupleT = TypeVar("_TupleT", bound=Tuple[object, ...])
_MappingT = TypeVar("_MappingT", bound=Mapping[str, object])
_SequenceT = TypeVar("_SequenceT", bound=Sequence[object])
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
    """Flatten one level of nesting: `[[1, 2], [3]] -> [1, 2, 3]`."""
    flat: list[_T] = []
    for sub in t:
        flat.extend(sub)
    return flat
def extract_files(
    # TODO: this needs to take Dict but variance issues.....
    # create protocol type ?
    query: Mapping[str, object],
    *,
    paths: Sequence[Sequence[str]],
) -> list[tuple[str, FileTypes]]:
    """Recursively extract files from the given dictionary based on specified paths.

    A path may look like this ['foo', 'files', '<array>', 'data'].

    Note: this mutates the given dictionary.
    """
    return [
        extracted
        for path in paths
        for extracted in _extract_items(query, path, index=0, flattened_key=None)
    ]
def _extract_items(
    obj: object,
    path: Sequence[str],
    *,
    index: int,
    flattened_key: str | None,
) -> list[tuple[str, FileTypes]]:
    """Recursive worker for `extract_files`.

    Walks `obj` along `path[index:]`, removing the final field from its parent
    dict (mutation!) and returning `(flattened_key, file)` pairs using
    bracketed multipart-style keys, e.g. `foo[files][]`.
    """
    try:
        key = path[index]
    except IndexError:
        if isinstance(obj, NotGiven):
            # no value was provided - we can safely ignore
            return []

        # cyclical import
        from .._files import assert_is_file_content

        # We have exhausted the path, return the entry we found.
        assert flattened_key is not None

        if is_list(obj):
            files: list[tuple[str, FileTypes]] = []
            for entry in obj:
                assert_is_file_content(entry, key=flattened_key + "[]" if flattened_key else "")
                files.append((flattened_key + "[]", cast(FileTypes, entry)))
            return files

        assert_is_file_content(obj, key=flattened_key)
        return [(flattened_key, cast(FileTypes, obj))]

    index += 1
    if is_dict(obj):
        try:
            # We are at the last entry in the path so we must remove the field
            if (len(path)) == index:
                item = obj.pop(key)
            else:
                item = obj[key]
        except KeyError:
            # Key was not present in the dictionary, this is not indicative of an error
            # as the given path may not point to a required field. We also do not want
            # to enforce required fields as the API may differ from the spec in some cases.
            return []
        if flattened_key is None:
            flattened_key = key
        else:
            flattened_key += f"[{key}]"
        return _extract_items(
            item,
            path,
            index=index,
            flattened_key=flattened_key,
        )
    elif is_list(obj):
        # lists are only traversed through an explicit `<array>` path segment
        if key != "<array>":
            return []

        return flatten(
            [
                _extract_items(
                    item,
                    path,
                    index=index,
                    flattened_key=flattened_key + "[]" if flattened_key is not None else "[]",
                )
                for item in obj
            ]
        )

    # Something unexpected was passed, just ignore it.
    return []
def is_given(obj: NotGivenOr[_T]) -> TypeGuard[_T]:
    """Narrow away the `NotGiven` sentinel: True when an actual value was supplied."""
    return not isinstance(obj, NotGiven)
# Type safe methods for narrowing types with TypeVars.
# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown],
# however this cause Pyright to rightfully report errors. As we know we don't
# care about the contained types we can safely use `object` in it's place.
#
# There are two separate functions defined, `is_*` and `is_*_t` for different use cases.
# `is_*` is for when you're dealing with an unknown input
# `is_*_t` is for when you're narrowing a known union type to a specific subset


def is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]:
    """Narrow an unknown value to a tuple."""
    return isinstance(obj, tuple)


def is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]:
    """Narrow a known union to its tuple member."""
    return isinstance(obj, tuple)


def is_sequence(obj: object) -> TypeGuard[Sequence[object]]:
    """Narrow an unknown value to a `Sequence` (note: strings are sequences too)."""
    return isinstance(obj, Sequence)


def is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]:
    """Narrow a known union to its `Sequence` member."""
    return isinstance(obj, Sequence)


def is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]:
    """Narrow an unknown value to a `Mapping`."""
    return isinstance(obj, Mapping)


def is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]:
    """Narrow a known union to its `Mapping` member."""
    return isinstance(obj, Mapping)


def is_dict(obj: object) -> TypeGuard[dict[object, object]]:
    """Narrow an unknown value to a plain `dict`."""
    return isinstance(obj, dict)


def is_list(obj: object) -> TypeGuard[list[object]]:
    """Narrow an unknown value to a plain `list`."""
    return isinstance(obj, list)


def is_iterable(obj: object) -> TypeGuard[Iterable[object]]:
    """Narrow an unknown value to an `Iterable`."""
    return isinstance(obj, Iterable)
def deepcopy_minimal(item: _T) -> _T:
    """Minimal reimplementation of copy.deepcopy() that will only copy certain object types:

    - mappings, e.g. `dict`
    - list

    Everything else is returned as-is. This is done for performance reasons.
    """
    if isinstance(item, Mapping):
        return cast(_T, {key: deepcopy_minimal(value) for key, value in item.items()})
    if isinstance(item, list):
        return cast(_T, [deepcopy_minimal(entry) for entry in item])
    return item
# copied from https://github.com/Rapptz/RoboDanny
def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str:
    """Join *seq* the way prose would: "", "a", "a or b", "a, b or c"."""
    if not seq:
        return ""
    if len(seq) == 1:
        return seq[0]
    if len(seq) == 2:
        return f"{seq[0]} {final} {seq[1]}"
    return f"{delim.join(seq[:-1])} {final} {seq[-1]}"
def quote(string: str) -> str:
    """Add single quotation marks around the given string. Does *not* do any escaping."""
    return "'" + string + "'"
def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
    """Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function.

    Useful for enforcing runtime validation of overloaded functions.

    Example usage:
    ```py
    @overload
    def foo(*, a: str) -> str: ...


    @overload
    def foo(*, b: bool) -> str: ...


    # This enforces the same constraints that a static type checker would
    # i.e. that either a or b must be passed to the function
    @required_args(["a"], ["b"])
    def foo(*, a: str | None = None, b: bool | None = None) -> str: ...
    ```
    """

    def inner(func: CallableT) -> CallableT:
        params = inspect.signature(func).parameters
        # names of parameters that may be supplied positionally, in declaration
        # order, so positional call-site args can be mapped back to names
        positional = [
            name
            for name, param in params.items()
            if param.kind
            in {
                param.POSITIONAL_ONLY,
                param.POSITIONAL_OR_KEYWORD,
            }
        ]

        @functools.wraps(func)
        def wrapper(*args: object, **kwargs: object) -> object:
            given_params: set[str] = set()
            for i, _ in enumerate(args):
                try:
                    given_params.add(positional[i])
                except IndexError:
                    raise TypeError(
                        f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given"
                    ) from None

            for key in kwargs.keys():
                given_params.add(key)

            # accept the call when at least one declared variant is fully satisfied;
            # the `else` below runs only if no `break` happened (no variant matched)
            for variant in variants:
                matches = all((param in given_params for param in variant))
                if matches:
                    break
            else:  # no break
                if len(variants) > 1:
                    variations = human_join(
                        ["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants]
                    )
                    msg = f"Missing required arguments; Expected either {variations} arguments to be given"
                else:
                    assert len(variants) > 0

                    # TODO: this error message is not deterministic
                    missing = list(set(variants[0]) - given_params)
                    if len(missing) > 1:
                        msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}"
                    else:
                        msg = f"Missing required argument: {quote(missing[0])}"
                raise TypeError(msg)
            return func(*args, **kwargs)

        return wrapper  # type: ignore

    return inner
# Key/value type variables for the typed `strip_not_given` overload below.
_K = TypeVar("_K")
_V = TypeVar("_V")


@overload
def strip_not_given(obj: None) -> None: ...


@overload
def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]: ...


@overload
def strip_not_given(obj: object) -> object: ...


def strip_not_given(obj: object | None) -> object:
    """Remove all top-level keys where their values are instances of `NotGiven`"""
    if obj is None:
        return None

    # non-mappings pass through untouched
    if not is_mapping(obj):
        return obj

    # only the top level is filtered; nested values are not recursed into
    return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}
def coerce_integer(val: str) -> int:
    """Parse *val* as a base-10 integer."""
    return int(val, 10)
def coerce_float(val: str) -> float:
    """Parse *val* as a float."""
    return float(val)
def coerce_boolean(val: str) -> bool:
    """Interpret the truthy strings "true", "1" and "on" as True; everything else is False."""
    return val in ("true", "1", "on")
def maybe_coerce_integer(val: str | None) -> int | None:
    """Like `coerce_integer` but passes `None` straight through."""
    return None if val is None else int(val, 10)
def maybe_coerce_float(val: str | None) -> float | None:
    """Like `coerce_float` but passes `None` straight through."""
    return None if val is None else float(val)
def maybe_coerce_boolean(val: str | None) -> bool | None:
    """Like `coerce_boolean` but passes `None` straight through."""
    return None if val is None else val in ("true", "1", "on")
def removeprefix(string: str, prefix: str) -> str:
    """Remove a prefix from a string.

    Backport of `str.removeprefix` for Python < 3.9
    """
    if not string.startswith(prefix):
        return string
    return string[len(prefix):]
def removesuffix(string: str, suffix: str) -> str:
    """Remove a suffix from a string.

    Backport of `str.removesuffix` for Python < 3.9

    Note: the empty-suffix guard is required — without it `string[:-len("")]`
    evaluates to `string[:0]` and wrongly returns the empty string (the
    PEP 616 reference implementation uses the same guard).
    """
    if suffix and string.endswith(suffix):
        return string[: -len(suffix)]
    return string
def file_from_path(path: str) -> FileTypes:
    """Read *path* fully into memory and return a ``(filename, contents)`` pair."""
    name = os.path.basename(path)
    return (name, Path(path).read_bytes())
def get_required_header(headers: HeadersLike, header: str) -> str:
    """Look up *header* case-insensitively in `headers`, raising `ValueError` if absent."""
    lower_header = header.lower()
    if is_mapping_t(headers):
        # mypy doesn't understand the type narrowing here
        for k, v in headers.items():  # type: ignore
            if k.lower() == lower_header and isinstance(v, str):
                return v

    # to deal with the case where the header looks like Stainless-Event-Id
    intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize())

    # NOTE(review): this fallback assumes non-mapping `HeadersLike` values still
    # expose a `.get` method — confirm against the `HeadersLike` definition
    for normalized_header in [header, lower_header, header.upper(), intercaps_header]:
        value = headers.get(normalized_header)
        if value:
            return value

    raise ValueError(f"Could not find {header} header")
def get_async_library() -> str:
    """Name of the running async library (e.g. "asyncio"), or the string "false" when none is detected."""
    try:
        return sniffio.current_async_library()
    except Exception:
        # NOTE(review): returning the *string* "false" looks deliberate
        # (presumably reported verbatim in platform/telemetry headers) —
        # confirm before changing it to a boolean or None.
        return "false"
def lru_cache(*, maxsize: int | None = 128) -> Callable[[CallableT], CallableT]:
    """A version of functools.lru_cache that retains the type signature
    for the wrapped function arguments.
    """
    return cast(Any, functools.lru_cache(maxsize=maxsize))  # noqa: TID251
def json_safe(data: object) -> object:
    """Translates a mapping / sequence recursively in the same fashion
    as `pydantic` v2's `model_dump(mode="json")`.

    Mappings become dicts, non-string iterables become lists and
    dates/datetimes become ISO-8601 strings; everything else is left alone.
    """
    if isinstance(data, Mapping):
        return {json_safe(key): json_safe(value) for key, value in data.items()}

    if isinstance(data, Iterable) and not isinstance(data, (str, bytes, bytearray)):
        return [json_safe(item) for item in data]

    if isinstance(data, (datetime, date)):
        return data.isoformat()

    return data

4
src/tinker/_version.py Normal file
View file

@ -0,0 +1,4 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.

__title__ = "tinker"  # distribution / package name
# NOTE(review): presumably bumped by the release tooling rather than by hand — confirm
__version__ = "0.0.1-alpha.1"

4
src/tinker/lib/.keep Normal file
View file

@ -0,0 +1,4 @@
File generated from our OpenAPI spec by Stainless.
This directory can be used to store custom files to expand the SDK.
It is ignored by Stainless code generation and its content (other than this keep file) won't be touched.

View file

@ -0,0 +1,266 @@
from __future__ import annotations
import asyncio
import contextlib
import inspect
import logging
import time
import traceback
from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, List, Type, TypeVar, cast
import tinker
from tinker import types
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
from tinker.lib.public_interfaces.api_future import APIFuture
from tinker.lib.telemetry import Telemetry
from .._models import BaseModel
from .retryable_exception import RetryableException
from .sync_only import sync_only
if TYPE_CHECKING:
from tinker.lib.internal_client_holder import InternalClientHolder
logger = logging.getLogger(__name__)

# Generic result/value type variables used by the future classes below.
T = TypeVar("T")
U = TypeVar("U")

# Sentinel object to indicate that the function hasn't been called yet
_UNCOMPUTED = object()
class QueueState(Enum):
    """Server-side processing state of the request queue, as observed by futures."""

    ACTIVE = "active"
    PAUSED_RATE_LIMIT = "paused_rate_limit"
    PAUSED_CAPACITY = "paused_capacity"
    UNKNOWN = "unknown"
class QueueStateObserver(ABC):
    """Callback interface for being notified when the observed `QueueState` changes."""

    @abstractmethod
    def on_queue_state_change(self, queue_state: QueueState) -> None:
        """Handle a transition to `queue_state`. Must be overridden by subclasses."""
        raise NotImplementedError
class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
    def __init__(
        self,
        model_cls: Type[T],
        holder: InternalClientHolder,
        untyped_future: types.UntypedAPIFuture,
        request_start_time: float,
        request_type: str,
        queue_state_observer: QueueStateObserver | None = None,
    ):
        """Wrap an untyped server-side future with typed result retrieval.

        Note: polling starts eagerly — `_result_async()` is scheduled on the
        holder's event loop here, not when the caller first awaits the result.
        """
        self.model_cls = model_cls
        self.holder = holder
        self.untyped_future = untyped_future
        self.request_type = request_type
        # `_UNCOMPUTED` sentinel until a result has been fetched and cached
        self._cached_result: Any = _UNCOMPUTED
        # This helps us collect telemetry about how long (1) it takes the
        # client to serialize the request, (2) round-trip time to the server
        # and back, and (3) how long the server takes to process the request.
        # We send this delta in a header to the server when retrieving the promise
        # result.
        self.request_start_time = request_start_time
        self.request_future_start_time = time.time()
        self.request_queue_roundtrip_time = self.request_future_start_time - request_start_time
        self._future = self.holder.run_coroutine_threadsafe(self._result_async())
        self._queue_state_observer: QueueStateObserver | None = queue_state_observer
async def _result_async(self, timeout: float | None = None) -> T:
"""Get the result of this future, with automatic retries for transient errors."""
if self._cached_result is not _UNCOMPUTED:
return cast(T, self._cached_result)
start_time = time.time()
iteration = -1
connection_error_retries = 0
while True:
iteration += 1
if timeout is not None and time.time() - start_time > timeout:
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
"APIFuture.result_async.timeout",
event_data={
"request_id": self.request_id,
"request_type": self.request_type,
"timeout": timeout,
"iteration": iteration,
"elapsed_time": current_time - start_time,
},
severity="ERROR",
)
raise TimeoutError(
f"Timeout of {timeout} seconds reached while waiting for result of {self.request_id=}"
)
# Headers for telemetry
headers = {
"X-Tinker-Request-Iteration": str(iteration),
"X-Tinker-Request-Type": self.request_type,
}
if iteration == 0:
headers["X-Tinker-Create-Promise-Roundtrip-Time"] = str(
self.request_queue_roundtrip_time
)
# Function hasn't been called yet, execute it now
try:
with self.holder.aclient(ClientConnectionPoolType.RETRIEVE_PROMISE) as client:
response = await client.futures.with_raw_response.retrieve(
request_id=self.request_id, timeout=45, extra_headers=headers, max_retries=0
)
except tinker.APIStatusError as e:
connection_error_retries = 0
should_retry = e.status_code == 408 or e.status_code in range(500, 600)
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
"APIFuture.result_async.api_status_error",
event_data={
"request_id": self.request_id,
"request_type": self.request_type,
"status_code": e.status_code,
"exception": str(e),
"should_retry": should_retry,
"iteration": iteration,
"elapsed_time": current_time - start_time,
},
severity="WARNING" if should_retry else "ERROR",
)
# Retry 408s until we time out
if e.status_code == 408:
if self._queue_state_observer is not None:
with contextlib.suppress(Exception):
response = e.response.json()
if queue_state_str := response.get("queue_state", None):
if queue_state_str == "active":
queue_state = QueueState.ACTIVE
elif queue_state_str == "paused_rate_limit":
queue_state = QueueState.PAUSED_RATE_LIMIT
elif queue_state_str == "paused_capacity":
queue_state = QueueState.PAUSED_CAPACITY
else:
queue_state = QueueState.UNKNOWN
self._queue_state_observer.on_queue_state_change(
queue_state
)
continue
if e.status_code == 410:
raise RetryableException(
message=f"Promise expired/broken for request {self.untyped_future.request_id}"
) from e
if e.status_code in range(500, 600):
continue
raise ValueError(
f"Error retrieving result: {e} with status code {e.status_code=} for {self.request_id=} and expected type {self.model_cls=}"
) from e
except tinker.APIConnectionError as e:
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
"APIFuture.result_async.connection_error",
event_data={
"request_id": self.request_id,
"request_type": self.request_type,
"exception": str(e),
"connection_error_retries": connection_error_retries,
"iteration": iteration,
"elapsed_time": current_time - start_time,
},
severity="WARNING",
)
# Retry all connection errors with exponential backoff
await asyncio.sleep(min(2**connection_error_retries, 30))
connection_error_retries += 1
continue
# Function hasn't been called yet, execute it now
result_dict: Any = await response.json()
if "type" in result_dict and result_dict["type"] == "try_again":
logger.warning(f"Retrying request {self.request_id=} because of try_again")
continue
if "error" in result_dict:
raise ValueError(
f"Error retrieving result: {result_dict} for {self.request_id=} and expected type {self.model_cls=}"
)
try:
# Check if model_cls is a BaseModel subclass before calling model_validate
if inspect.isclass(self.model_cls) and issubclass(self.model_cls, BaseModel):
self._cached_result = self.model_cls.model_validate(result_dict)
else:
# For non-BaseModel types, just return the result directly
self._cached_result = result_dict
return cast(T, self._cached_result)
except Exception as e:
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
"APIFuture.result_async.validation_error",
event_data={
"request_id": self.request_id,
"request_type": self.request_type,
"exception": str(e),
"exception_type": type(e).__name__,
"exception_stack": "".join(
traceback.format_exception(type(e), e, e.__traceback__)
)
if e.__traceback__
else None,
"model_cls": str(self.model_cls),
"iteration": iteration,
"elapsed_time": current_time - start_time,
},
severity="ERROR",
)
raise ValueError(
f"Error retrieving result: {e} for {self.request_id=} and expected type {self.model_cls=}"
) from e
@property
def request_id(self) -> str:
return self.untyped_future.request_id
@sync_only
def result(self, timeout: float | None = None) -> T:
return self._future.result(timeout)
async def result_async(self, timeout: float | None = None) -> T:
return await asyncio.wait_for(self._future, timeout)
def get_telemetry(self) -> Telemetry | None:
return self.holder.get_telemetry()
class _CombinedAPIFuture(APIFuture[T]):  # pyright: ignore[reportUnusedClass]
    """Fan-in future: resolves several APIFutures and folds them into one.

    All constituent futures are awaited concurrently; the ordered list of
    results is passed through ``transform`` to produce the final value.
    """

    def __init__(
        self,
        futures: List[APIFuture[T]],
        transform: Callable[[List[T]], T],
        holder: InternalClientHolder,
    ):
        self.futures = futures
        self.transform = transform
        self.holder = holder

    @sync_only
    def result(self, timeout: float | None = None) -> T:
        # Bridge onto the holder's background loop and block until complete.
        pending = self.holder.run_coroutine_threadsafe(self.result_async(timeout))
        return pending.result()

    async def result_async(self, timeout: float | None = None) -> T:
        # gather preserves input order, so transform sees results aligned
        # with self.futures.
        per_future = (child.result_async(timeout) for child in self.futures)
        resolved = await asyncio.gather(*per_future)
        return self.transform(list(resolved))

View file

@ -0,0 +1,28 @@
from __future__ import annotations
import asyncio
from collections.abc import Coroutine
from contextlib import AbstractContextManager
from typing import Any, Protocol, TypeVar
from tinker._client import AsyncTinker
from .public_interfaces.api_future import AwaitableConcurrentFuture
from .client_connection_pool_type import ClientConnectionPoolType
T = TypeVar("T")
class AsyncTinkerProvider(Protocol):
    """Structural interface for objects that own a background event loop and
    can hand out AsyncTinker clients bound to it."""

    # both of the following methods should be threadsafe
    def get_loop(self) -> asyncio.AbstractEventLoop: ...
    def run_coroutine_threadsafe(
        self,
        coro: Coroutine[Any, Any, T],
    ) -> AwaitableConcurrentFuture[T]: ...

    # must be called and used within the provided event loop
    def aclient(
        self, client_pool_type: ClientConnectionPoolType
    ) -> AbstractContextManager[AsyncTinker]: ...

View file

@ -0,0 +1,126 @@
import logging
from typing import Any, Dict, List, Sequence, Set, cast
import numpy as np
from tinker.types import ForwardBackwardOutput, LossFnOutput
logger = logging.getLogger(__name__)

# Metric name -> scalar value. The name encodes its reduction type after a
# colon, e.g. "mfu:mean" (see REDUCE_MAP and _metrics_reduction below).
Metrics = Dict[str, float]
def combine_fwd_bwd_output_results(
    results: Sequence[ForwardBackwardOutput],
) -> ForwardBackwardOutput:
    """Merge per-actor forward/backward outputs into one combined output.

    Metric values are folded via the reduction encoded in each metric name
    (see _metrics_reduction) and loss_fn_outputs are concatenated in actor
    order. An empty input yields an empty ForwardBackwardOutput.
    """
    if not results:
        return ForwardBackwardOutput(loss_fn_output_type="", metrics={}, loss_fn_outputs=[])
    reduced_metrics = _metrics_reduction(results)
    merged_outputs = _combine_loss_fn_outputs(results)
    first = results[0]
    return ForwardBackwardOutput(
        loss_fn_output_type=first.loss_fn_output_type,
        metrics=reduced_metrics,
        loss_fn_outputs=merged_outputs,
    )
def _combine_loss_fn_outputs(results: Sequence[ForwardBackwardOutput]) -> List[LossFnOutput]:
return [output for result in results for output in result.loss_fn_outputs]
def _order_insensitive_hash(xs: Sequence[Set[Any]] | Sequence[float]) -> int:
"""Combine hash values in an order-insensitive way.
Args:
xs: Either a sequence of sets (original data) or a sequence of already-computed hash values
"""
# If we have sets, flatten and hash them (original behavior)
if xs and isinstance(xs[0], set):
return hash(tuple(sorted([y for x in xs for y in cast(Set[Any], x)])))
# If we have already-computed hash values, combine them deterministically
# Sort them to ensure order-insensitive combination
return hash(tuple(sorted(int(cast(float, x)) for x in xs)))
def _mean(xs: Sequence[float | int], weights: Sequence[float] | None = None) -> float:
if weights is None:
return np.mean(xs).item()
return np.average(xs, weights=weights).item()
def _sum(xs: Sequence[float | int]) -> float:
return np.sum(xs).item()
def _min(xs: Sequence[float | int]) -> float:
return np.min(xs).item()
def _max(xs: Sequence[float | int]) -> float:
return np.max(xs).item()
def _slack(xs: Sequence[float | int], weights: Sequence[float] | None = None) -> float:
if weights is None:
return (np.max(xs) - np.mean(xs)).item()
return (np.max(xs) - np.average(xs, weights=weights)).item()
def _unique(xs: Sequence[float | int]) -> Sequence[float | int]:
    """
    A unique metric can't actually fold. But in order to work around the fact that
    it's a str:float dict, we just insert unique keys with some suffix for each
    unique value (see the "unique" branch of _metrics_reduction). It's a hack,
    that's for sure.
    This is a dummy identity function just for documentation and consistency purposes.
    """
    return xs
# Maps the reduction tag embedded in a metric name (the part after the colon,
# e.g. "mfu:mean") to the function that folds per-actor values into one scalar.
REDUCE_MAP = {
    "mean": _mean,
    "sum": _sum,
    "min": _min,
    "max": _max,
    "slack": _slack,
    "hash_unordered": _order_insensitive_hash,
    "unique": _unique,
}
def _metrics_reduction(results: Sequence[ForwardBackwardOutput]) -> Metrics:
    """Reduce metrics from all actors.

    Every metric name must encode a reduction type, e.g. "mfu:mean".
    Metrics are weighted by the number of loss_fn_outputs (data points)
    each actor processed.
    """
    if not results:
        return {}
    weights = [len(result.loss_fn_outputs) for result in results]
    reduced: Metrics = {}
    for key in results[0].metrics.keys():
        name, reduction = key.split(":")
        reduce_fn = REDUCE_MAP.get(reduction)
        if reduce_fn is None:
            # Can happen when a new reduction type is added
            logger.debug(
                f"Invalid {reduction=} for metric {name=}. Expecting one of {REDUCE_MAP.keys()}"
            )
            continue
        if any(key not in result.metrics for result in results):
            # Skip metrics that not every actor reported.
            continue
        values = [result.metrics[key] for result in results]
        if reduction in ("mean", "slack"):
            # These reductions are weighted by per-actor data-point counts.
            reduced[key] = reduce_fn(values, weights)
        elif reduction == "unique":
            # "unique" cannot fold: keep the first value under the original
            # key and emit the rest under suffixed keys.
            reduced[key] = values[0]
            reduced.update({f"{key}_{i + 1}": v for i, v in enumerate(values[1:])})
        else:
            reduced[key] = reduce_fn(values)
    return reduced

View file

@ -0,0 +1,8 @@
from enum import Enum
class ClientConnectionPoolType(Enum):
    """Traffic classes that each get their own pool of AsyncTinker clients,
    so sampling, training, promise polling, and telemetry requests do not
    share HTTP connections (see InternalClientHolder._get_client_connection_pool)."""

    SAMPLE = "sample"
    TRAIN = "train"
    RETRIEVE_PROMISE = "retrieve_promise"
    TELEMETRY = "telemetry"

View file

@ -0,0 +1,234 @@
"""Internal client holder for managing AsyncTinker clients."""
from __future__ import annotations
import asyncio
import logging
import threading
import time
import traceback
import uuid
from collections.abc import Coroutine, Generator
from contextlib import AbstractContextManager, contextmanager
from typing import Any, Awaitable, Callable, TypeVar
import httpx
from tinker._client import AsyncTinker
from tinker._exceptions import APIConnectionError, APIStatusError
from tinker.lib.async_tinker_provider import AsyncTinkerProvider
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
from tinker.lib.telemetry import Telemetry, init_telemetry
from tinker.lib.telemetry_provider import TelemetryProvider
logger = logging.getLogger(__name__)

T = TypeVar("T")

# Cap on concurrent in-flight requests per pooled AsyncTinker client; once
# every client in a pool is at this level, ClientConnectionPool creates more.
MAX_REQUESTS_PER_HTTPX_CLIENT = 50
class ClientConnectionPool:
    """Pool of AsyncTinker clients bound to one event loop.

    Each client serves at most ``max_requests_per_client`` concurrent
    requests; additional clients are created lazily once all existing ones
    are saturated.
    """

    def __init__(
        self,
        loop: asyncio.AbstractEventLoop,
        max_requests_per_client: int,
        constructor_kwargs: dict[str, Any],
    ):
        self._loop = loop
        self._max_requests_per_client = max_requests_per_client
        self._constructor_kwargs = constructor_kwargs
        self._clients: list[AsyncTinker] = []
        # Number of in-flight requests per client, parallel to self._clients.
        self._client_active_refcount: list[int] = []

    @contextmanager
    def aclient(self) -> Generator[AsyncTinker, None, None]:
        """Yield a client with spare capacity, tracking its in-flight count."""
        assert _current_loop() is self._loop, "AsyncTinker client called from incorrect event loop"
        chosen = next(
            (
                idx
                for idx, in_flight in enumerate(self._client_active_refcount)
                if in_flight < self._max_requests_per_client
            ),
            -1,
        )
        if chosen == -1:
            # Every existing client is saturated; spin up a fresh one.
            self._clients.append(AsyncTinker(**self._constructor_kwargs))
            self._client_active_refcount.append(0)
            chosen = len(self._clients) - 1
        self._client_active_refcount[chosen] += 1
        try:
            yield self._clients[chosen]
        finally:
            self._client_active_refcount[chosen] -= 1
class InternalClientHolderThreadSingleton:
    """Lazily starts one shared daemon thread running a dedicated event loop.

    The loop is created on first use and runs forever; all holders share it
    so synchronous callers can schedule coroutines from any thread.
    """

    def __init__(self):
        self._loop: asyncio.AbstractEventLoop | None = None
        self._thread: threading.Thread | None = None
        self._started: bool = False
        self._lifecycle_lock: threading.Lock = threading.Lock()

    def _ensure_started(self):
        """Start the background loop thread exactly once (double-checked locking)."""
        if self._started:
            return
        with self._lifecycle_lock:
            if self._started:
                return
            self._loop = asyncio.new_event_loop()
            self._thread = threading.Thread(target=self._run_loop_forever, daemon=True)
            self._thread.start()
            self._started = True

    def _run_loop_forever(self):
        # Thread entry point: drive the shared loop until process exit.
        assert self._loop is not None, "Loop must not be None"
        self._loop.run_forever()

    def get_loop(self) -> asyncio.AbstractEventLoop:
        """Return the shared background loop, starting it if necessary."""
        self._ensure_started()
        assert self._loop is not None, "Loop must not be None"
        return self._loop


_internal_client_holder_thread_singleton = InternalClientHolderThreadSingleton()
class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
    """Owns the background event loop, per-traffic-type client pools, and
    telemetry for one SDK session.

    Sync entry points schedule coroutines onto the shared background loop via
    ``run_coroutine_threadsafe``; HTTP traffic is routed through
    ClientConnectionPools created lazily per ClientConnectionPoolType.
    """

    def __init__(self, **kwargs: Any) -> None:
        # Forwarded verbatim to every AsyncTinker client the pools construct.
        self._constructor_kwargs = kwargs
        # So we can use async eventloop for parallel sampling requests
        # in sync code.
        self._loop: asyncio.AbstractEventLoop = _internal_client_holder_thread_singleton.get_loop()
        self._client_pools: dict[ClientConnectionPoolType, ClientConnectionPool] = {}
        self._sample_backoff_until: float | None = None
        self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400)
        # Session id namespaces idempotency keys and telemetry events.
        self._session_id: str = str(uuid.uuid4())
        self._telemetry: Telemetry | None = init_telemetry(self, session_id=self._session_id)
        self._training_client_counter: int = 0
        self._training_client_lock: threading.Lock = threading.Lock()
        self._unordered_id_counter: int = 0

    def _get_client_connection_pool(
        self, client_pool_type: ClientConnectionPoolType
    ) -> ClientConnectionPool:
        """Lazily create (and cache) the connection pool for a traffic type."""
        if client_pool_type not in self._client_pools:
            # TRAIN traffic gets one client per in-flight request; everything
            # else shares clients up to MAX_REQUESTS_PER_HTTPX_CLIENT.
            max_requests_per_client = (
                1
                if client_pool_type == ClientConnectionPoolType.TRAIN
                else MAX_REQUESTS_PER_HTTPX_CLIENT
            )
            self._client_pools[client_pool_type] = ClientConnectionPool(
                self.get_loop(), max_requests_per_client, self._constructor_kwargs
            )
        return self._client_pools[client_pool_type]

    def get_training_client_id(self) -> int:
        """Return a session-unique, monotonically increasing training client id."""
        with self._training_client_lock:
            training_client_id = self._training_client_counter
            self._training_client_counter += 1
            return training_client_id

    def aclient(
        self, client_pool_type: ClientConnectionPoolType
    ) -> AbstractContextManager[AsyncTinker]:
        """Check out an AsyncTinker client from the pool for this traffic type."""
        return self._get_client_connection_pool(client_pool_type).aclient()

    def get_loop(self) -> asyncio.AbstractEventLoop:
        return self._loop

    def get_telemetry(self) -> Telemetry | None:
        return self._telemetry

    def run_coroutine_threadsafe(
        self,
        coro: Coroutine[Any, Any, T],
    ) -> AwaitableConcurrentFuture[T]:
        """Schedule ``coro`` on the background loop; callable from any thread."""
        return AwaitableConcurrentFuture(asyncio.run_coroutine_threadsafe(coro, self.get_loop()))

    def close(self):
        # Stop telemetry; safe to call more than once.
        if telemetry := self._telemetry:
            telemetry.stop()

    def __del__(self):
        self.close()

    def make_training_client_idempotency_key(self, training_client_id: int, request_id: int) -> str:
        """Deterministic idempotency key for ordered, per-training-client requests."""
        return f"{self._session_id}:{training_client_id}:{request_id}"

    def make_idempotency_key(self) -> str:
        """Fresh idempotency key for requests with no ordering requirement."""
        # NOTE(review): unlike the training-client counter, this increment is
        # not lock-protected — confirm callers are single-threaded or that the
        # race is acceptable.
        self._unordered_id_counter += 1
        return f"{self._session_id}:unordered:{self._unordered_id_counter}"

    @staticmethod
    def _is_retryable_status_code(status_code: int) -> bool:
        # 408 timeout, 409 conflict, 429 rate limit, and all 5xx are transient.
        return status_code in (408, 409, 429) or (500 <= status_code < 600)

    @staticmethod
    def _is_retryable_exception(exception: Exception) -> bool:
        """True when the exception represents a transient, retry-worthy failure."""
        RETRYABLE_EXCEPTIONS = (
            asyncio.TimeoutError,
            APIConnectionError,
            httpx.TimeoutException,
        )
        if isinstance(exception, RETRYABLE_EXCEPTIONS):
            return True
        if isinstance(exception, APIStatusError):
            return InternalClientHolder._is_retryable_status_code(exception.status_code)
        return False

    async def execute_with_retries(
        self, func: Callable[..., Awaitable[T]], *args: Any, **kwargs: Any
    ) -> T:
        """Call ``func(*args, **kwargs)``, retrying transient failures.

        Retries with exponential backoff (each sleep capped at 30s) for up to
        five minutes total; non-retryable errors or an exhausted time budget
        re-raise the original exception.
        """
        MAX_WAIT_TIME = 60 * 5
        start_time = time.time()
        attempt_count = 0
        while True:
            try:
                return await func(*args, **kwargs)
            except Exception as e:
                is_retryable = self._is_retryable_exception(e)
                current_time = time.time()
                elapsed_time = current_time - start_time
                if telemetry := self.get_telemetry():
                    telemetry.log(
                        "InternalClientHolder.execute_with_retries.exception",
                        event_data={
                            "func": getattr(
                                func, "__qualname__", getattr(func, "__name__", type(func).__name__)
                            ),
                            "exception": str(e),
                            "exception_type": type(e).__name__,
                            "exception_stack": "".join(
                                traceback.format_exception(type(e), e, e.__traceback__)
                            )
                            if e.__traceback__
                            else None,
                            "status_code": getattr(e, "status_code", None),
                            "is_retryable": is_retryable,
                            "attempt_count": attempt_count,
                            "start_time": start_time,
                            "current_time": current_time,
                            "elapsed_time": elapsed_time,
                        },
                        severity="WARNING" if is_retryable else "ERROR",
                    )
                if is_retryable and elapsed_time < MAX_WAIT_TIME:
                    # Apply exponential backoff
                    time_to_wait = min(2**attempt_count, 30)
                    attempt_count += 1
                    # Don't wait too long if we're almost at the max wait time
                    time_to_wait = min(time_to_wait, start_time + MAX_WAIT_TIME - current_time)
                    await asyncio.sleep(time_to_wait)
                    continue
                raise e
def _current_loop() -> asyncio.AbstractEventLoop | None:
try:
return asyncio.get_running_loop()
except RuntimeError:
return None

View file

@ -0,0 +1,14 @@
"""Public interfaces for the Tinker client library."""
from .api_future import APIFuture, AwaitableConcurrentFuture
from .sampling_client import SamplingClient
from .service_client import ServiceClient
from .training_client import TrainingClient
# Explicit public API of tinker.lib.public_interfaces; star-imports and
# documentation tools honor this list.
__all__ = [
    "ServiceClient",
    "TrainingClient",
    "SamplingClient",
    "APIFuture",
    "AwaitableConcurrentFuture",
]

View file

@ -0,0 +1,40 @@
"""
API Future classes for handling async operations with retry logic.
"""
from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from concurrent.futures import Future as ConcurrentFuture
from typing import Generic, TypeVar
T = TypeVar("T")


class APIFuture(ABC, Generic[T]):
    """Abstract future returned by Tinker API calls.

    Implementations expose a blocking accessor (``result``) and an async
    accessor (``result_async``); awaiting the future directly is equivalent
    to awaiting ``result_async()`` with no timeout.
    """

    @abstractmethod
    async def result_async(self, timeout: float | None = None) -> T:
        """Await the result, raising TimeoutError after ``timeout`` seconds."""
        raise NotImplementedError

    @abstractmethod
    def result(self, timeout: float | None = None) -> T:
        """Block until the result is ready, or ``timeout`` seconds elapse."""
        raise NotImplementedError

    def __await__(self):
        return self.result_async().__await__()


class AwaitableConcurrentFuture(APIFuture[T]):
    """Adapter exposing a concurrent.futures.Future through the APIFuture interface."""

    def __init__(self, future: ConcurrentFuture[T]):
        self._future: ConcurrentFuture[T] = future

    def result(self, timeout: float | None = None) -> T:
        return self._future.result(timeout)

    async def result_async(self, timeout: float | None = None) -> T:
        # Use asyncio.wait_for rather than asyncio.timeout: the latter only
        # exists on Python 3.11+, while this project targets 3.9
        # (.python-version / devcontainer). wait_for with timeout=None waits
        # indefinitely, matching the previous semantics, and raises
        # asyncio.TimeoutError (TimeoutError on 3.11+) on expiry.
        return await asyncio.wait_for(asyncio.wrap_future(self._future), timeout)

    def future(self) -> ConcurrentFuture[T]:
        """Return the wrapped concurrent future."""
        return self._future

View file

@ -0,0 +1,397 @@
"""RestClient for Tinker API REST operations."""
from __future__ import annotations
import logging
from concurrent.futures import Future as ConcurrentFuture
from typing import TYPE_CHECKING
from tinker import types, NoneType
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
from tinker.lib.telemetry import Telemetry, capture_exceptions
from tinker.lib.telemetry_provider import TelemetryProvider
from ..sync_only import sync_only
if TYPE_CHECKING:
from ..internal_client_holder import InternalClientHolder
# pyright: reportPrivateImportUsage=false
logger = logging.getLogger(__name__)
class RestClient(TelemetryProvider):
"""Client for REST API operations like listing checkpoints and metadata.
The RestClient provides access to various REST endpoints for querying
model information, checkpoints, and other resources. You typically get one
by calling `service_client.create_rest_client()`.
Key methods:
- list_checkpoints() - list available model checkpoints (both training and sampler)
- get_training_run() - get model information and metadata as ModelEntry
- delete_checkpoint() - delete an existing checkpoint for a training run
- download_sampler_weights_archive() - download sampler weights checkpoint as tar.gz archive
Args:
holder: Internal client managing HTTP connections and async operations
Example:
>>> rest_client = service_client.create_rest_client()
>>> training_run = rest_client.get_training_run("run-id").result()
>>> print(f"Training Run: {training_run.training_run_id}, LoRA: {training_run.is_lora}")
>>> checkpoints = rest_client.list_checkpoints("run-id").result()
>>> print(f"Found {len(checkpoints.checkpoints)} checkpoints")
>>> for checkpoint in checkpoints.checkpoints:
... print(f" {checkpoint.checkpoint_type}: {checkpoint.checkpoint_id}")
"""
def __init__(self, holder: InternalClientHolder):
self.holder = holder
def _get_training_run_submit(
self, training_run_id: types.ModelID
) -> AwaitableConcurrentFuture[types.TrainingRun]:
"""Internal method to submit get model request."""
async def _get_training_run_async() -> types.TrainingRun:
async def _send_request() -> types.TrainingRun:
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
return await client.get(
f"/api/v1/training_runs/{training_run_id}",
cast_to=types.TrainingRun,
)
return await self.holder.execute_with_retries(_send_request)
return self.holder.run_coroutine_threadsafe(_get_training_run_async())
@sync_only
@capture_exceptions(fatal=True)
def get_training_run(self, training_run_id: types.ModelID) -> ConcurrentFuture[types.TrainingRun]:
"""Get training run info.
Args:
training_run_id: The training run ID to get information for
Returns:
A Future containing the training run information
Example:
>>> future = rest_client.get_training_run("run-id")
>>> response = future.result()
>>> print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}")
"""
return self._get_training_run_submit(training_run_id).future()
@capture_exceptions(fatal=True)
async def get_training_run_async(self, training_run_id: types.ModelID) -> types.TrainingRun:
"""Async version of get_training_run.
Args:
training_run_id: The training run ID to get information for
Returns:
Training run information
Example:
>>> response = await rest_client.get_training_run_async("run-id")
>>> print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}")
"""
return await self._get_training_run_submit(training_run_id)
@sync_only
@capture_exceptions(fatal=True)
def get_training_run_by_tinker_path(self, tinker_path: str) -> ConcurrentFuture[types.TrainingRun]:
"""Get training run info.
Args:
tinker_path: The tinker path to the checkpoint
Returns:
A Future containing the training run information
Example:
>>> future = rest_client.get_training_run_by_tinker_path("tinker://run-id/weights/checkpoint-001")
>>> response = future.result()
>>> print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}")
"""
parsed_checkpoint_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(
tinker_path
)
return self.get_training_run(parsed_checkpoint_tinker_path.training_run_id)
@capture_exceptions(fatal=True)
async def get_training_run_by_tinker_path_async(self, tinker_path: str) -> types.TrainingRun:
"""Async version of get_training_run.
Args:
tinker_path: The tinker path to the checkpoint
Returns:
Training run information
Example:
>>> response = await rest_client.get_training_run_by_tinker_path_async("tinker://run-id/weights/checkpoint-001")
>>> print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}")
"""
parsed_checkpoint_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(
tinker_path
)
return await self.get_training_run_async(parsed_checkpoint_tinker_path.training_run_id)
def _list_training_runs_submit(
self, limit: int = 20, offset: int = 0
) -> AwaitableConcurrentFuture[types.TrainingRunsResponse]:
"""Internal method to submit list training runs request."""
async def _list_training_runs_async() -> types.TrainingRunsResponse:
async def _send_request() -> types.TrainingRunsResponse:
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
params: dict[str, object] = {"limit": limit, "offset": offset}
return await client.get(
"/api/v1/training_runs",
options={"params": params},
cast_to=types.TrainingRunsResponse,
)
return await self.holder.execute_with_retries(_send_request)
return self.holder.run_coroutine_threadsafe(_list_training_runs_async())
@sync_only
@capture_exceptions(fatal=True)
def list_training_runs(
self, limit: int = 20, offset: int = 0
) -> ConcurrentFuture[types.TrainingRunsResponse]:
"""List training runs with pagination support.
Args:
limit: Maximum number of training runs to return (default 20)
offset: Offset for pagination (default 0)
Returns:
A Future containing the TrainingRunsResponse with training runs and cursor info
Example:
>>> future = rest_client.list_training_runs(limit=50)
>>> response = future.result()
>>> print(f"Found {len(response.training_runs)} training runs")
>>> print(f"Total: {response.cursor.total_count}")
>>> # Get next page
>>> next_page = rest_client.list_training_runs(limit=50, offset=50)
"""
return self._list_training_runs_submit(limit, offset).future()
@capture_exceptions(fatal=True)
async def list_training_runs_async(
self, limit: int = 20, offset: int = 0
) -> types.TrainingRunsResponse:
"""Async version of list_training_runs.
Args:
limit: Maximum number of training runs to return (default 20)
offset: Offset for pagination (default 0)
Returns:
TrainingRunsResponse with training runs and cursor info
Example:
>>> response = await rest_client.list_training_runs_async(limit=50)
>>> print(f"Found {len(response.training_runs)} training runs")
>>> print(f"Total: {response.cursor.total_count}")
>>> # Get next page
>>> next_page = await rest_client.list_training_runs_async(limit=50, offset=50)
"""
return await self._list_training_runs_submit(limit, offset)
def _list_checkpoints_submit(
self, training_run_id: types.ModelID
) -> AwaitableConcurrentFuture[types.CheckpointsListResponse]:
"""Internal method to submit list model checkpoints request."""
async def _list_checkpoints_async():
async def _send_request():
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
return await client.weights.list(training_run_id)
return await self.holder.execute_with_retries(_send_request)
return self.holder.run_coroutine_threadsafe(_list_checkpoints_async())
@sync_only
@capture_exceptions(fatal=True)
def list_checkpoints(self, training_run_id: types.ModelID) -> ConcurrentFuture[types.CheckpointsListResponse]:
"""List available checkpoints (both training and sampler).
Args:
training_run_id: The training run ID to list checkpoints for
Returns:
A Future containing the CheckpointsListResponse with available checkpoints
Example:
>>> future = rest_client.list_checkpoints("run-id")
>>> response = future.result()
>>> for checkpoint in response.checkpoints:
... if checkpoint.checkpoint_type == "training":
... print(f"Training checkpoint: {checkpoint.checkpoint_id}")
... elif checkpoint.checkpoint_type == "sampler":
... print(f"Sampler checkpoint: {checkpoint.checkpoint_id}")
"""
return self._list_checkpoints_submit(training_run_id).future()
@capture_exceptions(fatal=True)
async def list_checkpoints_async(self, training_run_id: types.ModelID) -> types.CheckpointsListResponse:
"""Async version of list_checkpoints.
Args:
training_run_id: The training run ID to list checkpoints for
Returns:
CheckpointsListResponse with available checkpoints
Example:
>>> response = await rest_client.list_checkpoints_async("run-id")
>>> for checkpoint in response.checkpoints:
... if checkpoint.checkpoint_type == "training":
... print(f"Training checkpoint: {checkpoint.checkpoint_id}")
... elif checkpoint.checkpoint_type == "sampler":
... print(f"Sampler checkpoint: {checkpoint.checkpoint_id}")
"""
return await self._list_checkpoints_submit(training_run_id)
def _download_checkpoint_archive_submit(
self, training_run_id: types.ModelID, checkpoint_id: str
) -> AwaitableConcurrentFuture[bytes]:
"""Internal method to submit download checkpoint archive request."""
async def _download_checkpoint_archive_async():
async def _send_request():
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
return await client.get(
f"/api/v1/training_runs/{training_run_id}/checkpoints/{checkpoint_id}/archive",
cast_to=bytes,
options={"headers": {"accept": "application/gzip"}},
)
return await self.holder.execute_with_retries(_send_request)
return self.holder.run_coroutine_threadsafe(_download_checkpoint_archive_async())
@sync_only
@capture_exceptions(fatal=True)
def download_checkpoint_archive(
self, training_run_id: types.ModelID, checkpoint_id: str
) -> ConcurrentFuture[bytes]:
"""Download checkpoint as a tar.gz archive.
Args:
training_run_id: The training run ID to download weights for
checkpoint_id: The checkpoint ID to download
Returns:
A Future containing the archive data as bytes
Example:
>>> future = rest_client.download_checkpoint_archive("run-id", "checkpoint-123")
>>> archive_data = future.result()
>>> with open(f"model-checkpoint.tar.gz", "wb") as f:
... f.write(archive_data)
"""
return self._download_checkpoint_archive_submit(training_run_id, checkpoint_id).future()
@capture_exceptions(fatal=True)
async def download_checkpoint_archive_async(
self, training_run_id: types.ModelID, checkpoint_id: str
) -> bytes:
"""Async version of download_checkpoint_archive.
Args:
training_run_id: The model ID to download weights for
checkpoint_id: The checkpoint ID to download
Returns:
Archive data as bytes
Example:
>>> archive_data = await rest_client.download_checkpoint_archive_async("run-id", "checkpoint-123")
>>> with open(f"model-checkpoint.tar.gz", "wb") as f:
... f.write(archive_data)
"""
return await self._download_checkpoint_archive_submit(training_run_id, checkpoint_id)
def _delete_checkpoint_submit(
self, training_run_id: types.ModelID, checkpoint_id: str
) -> AwaitableConcurrentFuture[None]:
"""Internal method to submit delete checkpoint request."""
async def _delete_checkpoint_async() -> None:
async def _send_request() -> None:
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
await client.delete(
f"/api/v1/training_runs/{training_run_id}/checkpoints/{checkpoint_id}",
cast_to=NoneType,
)
return await self.holder.execute_with_retries(_send_request)
return self.holder.run_coroutine_threadsafe(_delete_checkpoint_async())
@sync_only
@capture_exceptions(fatal=True)
def delete_checkpoint(self, training_run_id: types.ModelID, checkpoint_id: str) -> ConcurrentFuture[None]:
"""Delete a checkpoint for a training run."""
return self._delete_checkpoint_submit(training_run_id, checkpoint_id).future()
@capture_exceptions(fatal=True)
async def delete_checkpoint_async(self, training_run_id: types.ModelID, checkpoint_id: str) -> None:
"""Async version of delete_checkpoint."""
await self._delete_checkpoint_submit(training_run_id, checkpoint_id)
@sync_only
@capture_exceptions(fatal=True)
def delete_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]:
"""Delete a checkpoint referenced by a tinker path."""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return self._delete_checkpoint_submit(parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id).future()
@capture_exceptions(fatal=True)
async def delete_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
"""Async version of delete_checkpoint_from_tinker_path."""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
await self._delete_checkpoint_submit(parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id)
def get_telemetry(self) -> Telemetry | None:
return self.holder.get_telemetry()
@sync_only
@capture_exceptions(fatal=True)
def download_checkpoint_archive_from_tinker_path(
self, tinker_path: str
) -> ConcurrentFuture[bytes]:
"""Download checkpoint as a tar.gz archive.
Args:
tinker_path: The tinker path to the checkpoint
Returns:
A Future containing the archive data as bytes
"""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return self._download_checkpoint_archive_submit(parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id).future()
@capture_exceptions(fatal=True)
async def download_checkpoint_archive_from_tinker_path_async(
    self, tinker_path: str
) -> bytes:
    """Async version of download_checkpoint_archive_from_tinker_path.

    Args:
        tinker_path: The tinker path to the checkpoint
    """
    parsed = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
    submit = self._download_checkpoint_archive_submit(
        parsed.training_run_id, parsed.checkpoint_id
    )
    return await submit

View file

@ -0,0 +1,231 @@
"""SamplingClient for Tinker API."""
from __future__ import annotations
import asyncio
import logging
import os
import time
from collections.abc import Sequence
from concurrent.futures import Future as ConcurrentFuture
from functools import lru_cache
from typing import TYPE_CHECKING, TypeVar, cast
import tinker
from tinker import types
from tinker._types import NOT_GIVEN
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
from tinker.lib.telemetry import Telemetry, capture_exceptions
from tinker.lib.telemetry_provider import TelemetryProvider
from ..api_future_impl import QueueState, QueueStateObserver, _APIFuture
from ..retry_handler import RetryConfig, RetryHandler
if TYPE_CHECKING:
from ..internal_client_holder import InternalClientHolder
# pyright: reportPrivateImportUsage=false
logger = logging.getLogger(__name__)
U = TypeVar("U")
class SamplingClient(TelemetryProvider, QueueStateObserver):
    """Client for text generation and inference from trained or base models.

    The SamplingClient lets you generate text tokens from either a base model or from weights
    you've saved using a TrainingClient. You typically get one by calling
    `service_client.create_sampling_client()` or `training_client.save_weights_and_get_sampling_client()`.

    Key methods:
    - sample() - generate text completions with customizable parameters
    - compute_logprobs() - get log probabilities for prompt tokens

    Args:
        holder: Internal client managing HTTP connections and async operations
        model_path: Path to saved model weights (starts with 'tinker://')
        base_model: Name of base model to use for inference
        retry_config: Configuration for retrying failed requests

    Example:
        >>> sampling_client = service_client.create_sampling_client(base_model="Qwen/Qwen2.5-7B")
        >>> prompt = types.ModelInput.from_ints(tokenizer.encode("The weather today is"))
        >>> params = types.SamplingParams(max_tokens=20, temperature=0.7)
        >>> future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
        >>> result = future.result()
    """

    def __init__(
        self,
        holder: InternalClientHolder,
        *,
        model_path: str | None = None,
        base_model: str | None = None,
        retry_config: RetryConfig | None = None,
    ):
        # Fail fast on malformed weight paths; only 'tinker://' URIs are accepted.
        if model_path and not model_path.startswith("tinker://"):
            raise ValueError("model_path must start with 'tinker://'")
        self.holder = holder
        self.model_path = model_path
        self.base_model = base_model
        # Create retry handler with the provided configuration.
        # Handlers are shared per (name, config, telemetry) via a module-level lru_cache.
        self.retry_handler = _get_retry_handler(
            model_path or base_model, retry_config=retry_config, telemetry=holder.get_telemetry()
        )
        # Comma-separated gate names from the environment, e.g. TINKER_FEATURE_GATES="a,b".
        self.feature_gates = set(
            os.environ.get("TINKER_FEATURE_GATES", "async_sampling").split(",")
        )
        # Epoch seconds of the last queue-state warning; used to rate-limit log spam.
        self._last_queue_state_logged: float = 0

    async def _send_asample_request(
        self,
        num_samples: int,
        prompt: types.ModelInput,
        sampling_params: types.SamplingParams,
        include_prompt_logprobs: bool,
        idempotency_key: str,
    ):
        """Send one asample RPC; returns None on HTTP 429 so the caller can back off."""
        try:
            with self.holder.aclient(ClientConnectionPoolType.SAMPLE) as client:
                return await client.sampling.asample(
                    num_samples=num_samples,
                    prompt=cast(types._ModelInputParam, prompt.model_dump()),
                    sampling_params=cast(types._SamplingParamsParam, sampling_params.model_dump()),
                    model_path=self.model_path if self.model_path is not None else NOT_GIVEN,
                    prompt_logprobs=include_prompt_logprobs,
                    base_model=self.base_model if self.base_model is not None else NOT_GIVEN,
                    # Retries are handled by execute_with_retries / retry_handler, not the transport.
                    max_retries=0,
                    extra_headers={"X-Tinker-Sampling-Backpressure": "1"},
                    idempotency_key=idempotency_key,
                )
        except tinker.APIStatusError as e:
            # 429 is the server's backpressure signal; surface it as None rather than raising.
            if e.status_code == 429:
                return None
            raise e

    async def _sample_async_impl(
        self,
        prompt: types.ModelInput,
        num_samples: int,
        sampling_params: types.SamplingParams,
        include_prompt_logprobs: bool,
    ) -> types.SampleResponse:
        """Dispatch one sample request under the holder-wide semaphore and await its result.

        The idempotency key is created once so retried dispatches are deduplicated
        server-side. A shared backoff window (holder._sample_backoff_until) pauses
        all dispatchers for ~1s after any 429.
        """
        idempotency_key = self.holder.make_idempotency_key()
        async with self.holder._sample_dispatch_semaphore:
            while True:
                # While inside the shared backoff window, poll once per second.
                if (
                    self.holder._sample_backoff_until is not None
                    and time.time() < self.holder._sample_backoff_until
                ):
                    await asyncio.sleep(1)
                    continue
                untyped_future = await self.holder.execute_with_retries(
                    self._send_asample_request,
                    num_samples,
                    prompt,
                    sampling_params,
                    include_prompt_logprobs,
                    idempotency_key,
                )
                if untyped_future is not None:
                    break
                # Handle backoff: _send_asample_request returned None (429).
                self.holder._sample_backoff_until = time.time() + 1
                continue
        # NOTE(review): semaphore is released here, before awaiting the result —
        # it guards dispatch concurrency only; confirm against holder semantics.
        return await _APIFuture(
            types.SampleResponse,
            self.holder,
            untyped_future,
            request_start_time=time.time(),
            request_type="Sample",
            queue_state_observer=self,
        ).result_async()

    @capture_exceptions(fatal=True)
    def sample(
        self,
        prompt: types.ModelInput,
        num_samples: int,
        sampling_params: types.SamplingParams,
        include_prompt_logprobs: bool = False,
    ) -> ConcurrentFuture[types.SampleResponse]:
        """Generate `num_samples` completions for `prompt` (with retries).

        Returns:
            A concurrent future resolving to the SampleResponse.
        """

        async def _sample_async():
            return await self._sample_async_impl(
                prompt, num_samples, sampling_params, include_prompt_logprobs
            )

        @capture_exceptions(fatal=True)
        async def _sample_async_with_retries() -> types.SampleResponse:
            return await self.retry_handler.execute(_sample_async)

        # TODO make max_tokens a required field
        return self.holder.run_coroutine_threadsafe(_sample_async_with_retries()).future()

    async def sample_async(
        self,
        prompt: types.ModelInput,
        num_samples: int,
        sampling_params: types.SamplingParams,
        include_prompt_logprobs: bool = False,
    ) -> types.SampleResponse:
        """Async version of sample(); awaits the same underlying future."""
        return await AwaitableConcurrentFuture(
            self.sample(prompt, num_samples, sampling_params, include_prompt_logprobs)
        )

    @capture_exceptions(fatal=True)
    def compute_logprobs(
        self, prompt: types.ModelInput
    ) -> ConcurrentFuture[Sequence[float | None]]:
        """Return per-token log probabilities for the prompt.

        Implemented as a 1-sample, 1-token sampling call with prompt_logprobs enabled.
        """

        async def _compute_logprobs_async() -> Sequence[float | None]:
            sample_res = await self._sample_async_impl(
                prompt,
                num_samples=1,
                sampling_params=types.SamplingParams(max_tokens=1),
                include_prompt_logprobs=True,
            )
            return cast(list[float | None], sample_res.prompt_logprobs)

        @capture_exceptions(fatal=True)
        async def _compute_logprobs_async_with_retries() -> Sequence[float | None]:
            return await self.retry_handler.execute(_compute_logprobs_async)

        return self.holder.run_coroutine_threadsafe(_compute_logprobs_async_with_retries()).future()

    async def compute_logprobs_async(self, prompt: types.ModelInput) -> Sequence[float | None]:
        """Async version of compute_logprobs."""
        return await AwaitableConcurrentFuture(self.compute_logprobs(prompt))

    def get_telemetry(self) -> Telemetry | None:
        """Return the holder's Telemetry instance, or None if telemetry is disabled."""
        return self.holder.get_telemetry()

    def on_queue_state_change(self, queue_state: QueueState) -> None:
        """Log (at most once a minute) when sampling for this model is paused."""
        QUEUE_STATE_LOG_INTERVAL = 60
        if queue_state == QueueState.ACTIVE:
            return
        if time.time() - self._last_queue_state_logged < QUEUE_STATE_LOG_INTERVAL:
            return
        if queue_state == QueueState.PAUSED_RATE_LIMIT:
            reason = "concurrent LoRA rate limit hit"
        elif queue_state == QueueState.PAUSED_CAPACITY:
            reason = "out of capacity"
        else:
            reason = "unknown"
        self._last_queue_state_logged = time.time()
        logger.warning(f"Sampling is paused for {self.model_path}. Reason: {reason}")
@lru_cache(maxsize=100)
def _get_retry_handler(
    name: str | None, retry_config: RetryConfig | None = None, telemetry: Telemetry | None = None
) -> RetryHandler:
    """Return a RetryHandler, shared across clients that target the same model.

    `name` is the model path or base-model name (may be None when the caller
    supplied neither). NOTE(review): lru_cache keys include retry_config and
    telemetry, so both must be hashable, and cached entries keep the telemetry
    object alive — confirm that is intended.
    """
    retry_config = retry_config or RetryConfig()
    return RetryHandler(config=retry_config, name=name, telemetry=telemetry)

View file

@ -0,0 +1,247 @@
"""ServiceClient for Tinker API."""
from __future__ import annotations
import logging
import os
import time
from typing import TYPE_CHECKING, Any, cast
from tinker import types
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
from tinker.lib.telemetry import Telemetry, capture_exceptions
from tinker.lib.telemetry_provider import TelemetryProvider
from ..api_future_impl import _APIFuture
from ..internal_client_holder import InternalClientHolder
from ..retry_handler import RetryConfig
from ..sync_only import sync_only
if TYPE_CHECKING:
from .rest_client import RestClient
from .sampling_client import SamplingClient
from .training_client import TrainingClient
# pyright: reportPrivateImportUsage=false
logger = logging.getLogger(__name__)
class ServiceClient(TelemetryProvider):
    """The ServiceClient is the main entry point for the Tinker API. It provides methods to:

    - Query server capabilities and health status
    - Generate TrainingClient instances for model training workflows
    - Generate SamplingClient instances for text generation and inference
    - Generate RestClient instances for REST API operations like listing weights

    Args:
        **kwargs: advanced options passed to the underlying HTTP client,
            including API keys, headers, and connection settings.

    Example:
        >>> client = ServiceClient()
        # ^^^ near-instant
        >>> training_client = client.create_lora_training_client(base_model="Qwen/Qwen3-8B")
        # ^^^ takes a moment as we initialize the model and assign resources
        >>> sampling_client = client.create_sampling_client(base_model="Qwen/Qwen3-8B")
        # ^^^ near-instant
        >>> rest_client = client.create_rest_client()
        # ^^^ near-instant
    """

    def __init__(self, **kwargs: Any):
        # Caller-supplied default_headers win over the environment-derived defaults.
        default_headers = _get_default_headers() | kwargs.pop("default_headers", {})
        self.holder = InternalClientHolder(
            **kwargs, default_headers=default_headers, _strict_response_validation=True
        )

    def _get_server_capabilities_submit(
        self,
    ) -> AwaitableConcurrentFuture[types.GetServerCapabilitiesResponse]:
        """Submit a get_server_capabilities RPC on the holder's event loop."""

        async def _get_server_capabilities_async():
            async def _send_request():
                with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
                    return await client.service.get_server_capabilities()

            return await self.holder.execute_with_retries(_send_request)

        return self.holder.run_coroutine_threadsafe(_get_server_capabilities_async())

    @sync_only
    @capture_exceptions(fatal=True)
    def get_server_capabilities(self) -> types.GetServerCapabilitiesResponse:
        """Return the server's advertised capabilities (blocking)."""
        return self._get_server_capabilities_submit().result()

    @capture_exceptions(fatal=True)
    async def get_server_capabilities_async(self) -> types.GetServerCapabilitiesResponse:
        """Async version of get_server_capabilities."""
        return await self._get_server_capabilities_submit()

    def _create_model_submit(
        self, base_model: str, lora_config: types.LoraConfig
    ) -> AwaitableConcurrentFuture[types.ModelID]:
        """Submit model creation and resolve to the new model's ID."""

        async def _create_model_async():
            start_time = time.time()
            with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
                future = await client.models.create(
                    base_model=base_model, lora_config=_to_lora_config_params(lora_config)
                )
            create_model_response = await _APIFuture(
                types.CreateModelResponse,
                self.holder,
                future,
                request_start_time=start_time,
                request_type="CreateModel",
            ).result_async()
            return create_model_response.model_id

        return self.holder.run_coroutine_threadsafe(_create_model_async())

    @sync_only
    @capture_exceptions(fatal=True)
    def create_lora_training_client(
        self,
        base_model: str,
        rank: int = 32,
        seed: int | None = None,
        train_mlp: bool = True,
        train_attn: bool = True,
        train_unembed: bool = True,
    ) -> TrainingClient:
        """Create a new LoRA model on the server and return a TrainingClient for it.

        Args:
            base_model: Name of the base model to fine-tune.
            rank: LoRA rank.
            seed: Optional initialization seed.
            train_mlp / train_attn / train_unembed: Which weight groups get LoRA
                adapters; at least one must be True.
        """
        assert any([train_mlp, train_attn, train_unembed]), (
            "At least one of train_mlp, train_attn, or train_unembed must be True"
        )
        model_id = self._create_model_submit(
            base_model,
            types.LoraConfig(
                rank=rank,
                seed=seed,
                train_mlp=train_mlp,
                train_attn=train_attn,
                train_unembed=train_unembed,
            ),
        ).result()
        logger.info(f"Creating TrainingClient for {model_id=}")
        return self.create_training_client(model_id)

    @capture_exceptions(fatal=True)
    async def create_lora_training_client_async(
        self,
        base_model: str,
        rank: int = 32,
        seed: int | None = None,
        train_mlp: bool = True,
        train_attn: bool = True,
        train_unembed: bool = True,
    ) -> TrainingClient:
        """Async version of create_lora_training_client."""
        assert any([train_mlp, train_attn, train_unembed]), (
            "At least one of train_mlp, train_attn, or train_unembed must be True"
        )
        model_id = await self._create_model_submit(
            base_model,
            types.LoraConfig(
                rank=rank,
                seed=seed,
                train_mlp=train_mlp,
                train_attn=train_attn,
                train_unembed=train_unembed,
            ),
        )
        logger.info(f"Creating TrainingClient for {model_id=}")
        return self.create_training_client(model_id)

    @capture_exceptions(fatal=True)
    def create_training_client(self, model_id: types.ModelID | None = None) -> TrainingClient:
        """Wrap an existing (or not-yet-set) model_id in a TrainingClient."""
        from .training_client import TrainingClient

        return TrainingClient(self.holder, model_id=model_id)

    @sync_only
    @capture_exceptions(fatal=True)
    def create_training_client_from_state(self, path: str) -> TrainingClient:
        """Recreate a training client from a saved state path and load its weights."""
        rest_client = self.create_rest_client()
        training_run = rest_client.get_training_run_by_tinker_path(path).result()
        # Right now all training runs are LoRA runs. Mirror the guard used by the
        # async variant so a missing lora_rank fails loudly here instead of being
        # passed as rank=None downstream.
        assert training_run.is_lora and training_run.lora_rank is not None
        training_client = self.create_lora_training_client(
            base_model=training_run.base_model,
            rank=training_run.lora_rank,
        )
        training_client.load_state(path).result()
        return training_client

    @capture_exceptions(fatal=True)
    async def create_training_client_from_state_async(self, path: str) -> TrainingClient:
        """Async version of create_training_client_from_state."""
        rest_client = self.create_rest_client()
        training_run = await rest_client.get_training_run_by_tinker_path_async(path)
        # Right now all training runs are LoRa runs.
        assert training_run.is_lora and training_run.lora_rank is not None
        training_client = await self.create_lora_training_client_async(
            base_model=training_run.base_model,
            rank=training_run.lora_rank,
        )
        load_future = await training_client.load_state_async(path)
        await load_future.result_async()
        return training_client

    @capture_exceptions(fatal=True)
    def create_sampling_client(
        self,
        model_path: str | None = None,
        base_model: str | None = None,
        retry_config: RetryConfig | None = None,
    ) -> SamplingClient:
        """Create a SamplingClient for a saved weights path or a base model.

        Raises:
            ValueError: if neither model_path nor base_model is provided.
        """
        from .sampling_client import SamplingClient

        if model_path is None and base_model is None:
            raise ValueError("Either model_path or base_model must be provided")
        return SamplingClient(
            self.holder,
            model_path=model_path,
            base_model=base_model,
            retry_config=retry_config,
        )

    @capture_exceptions(fatal=True)
    def create_rest_client(self) -> RestClient:
        """Create a RestClient for REST API operations.

        Returns:
            RestClient: A client for listing weights and other REST operations

        Example:
            >>> rest_client = service_client.create_rest_client()
            >>> weights = rest_client.list_model_weights("my-model-id").result()
        """
        from .rest_client import RestClient

        return RestClient(self.holder)

    def get_telemetry(self) -> Telemetry | None:
        """Return the holder's Telemetry instance, or None if telemetry is disabled."""
        return self.holder.get_telemetry()
def _get_default_headers() -> dict[str, str]:
    """Assemble default request headers from the process environment.

    - X-API-Key: from TINKER_API_KEY, only when non-empty.
    - X-Username: from USER, always present (may be empty).
    - CF-Access-Client-Id / CF-Access-Client-Secret: from the Cloudflare Access
      service-token variables, only when non-empty.
    """
    headers: dict[str, str] = {}
    api_key = os.environ.get("TINKER_API_KEY", "")
    if api_key:
        headers["X-API-Key"] = api_key
    headers["X-Username"] = os.environ.get("USER", "")
    cf_id = os.environ.get("CLOUDFLARE_ACCESS_CLIENT_ID")
    if cf_id:
        headers["CF-Access-Client-Id"] = cf_id
    cf_secret = os.environ.get("CLOUDFLARE_ACCESS_CLIENT_SECRET")
    if cf_secret:
        headers["CF-Access-Client-Secret"] = cf_secret
    return headers
def _to_lora_config_params(x: types.LoraConfig) -> types._LoraConfigParam:
    """Serialize a LoraConfig model into its request-param dict form."""
    dumped = x.model_dump()
    return cast(types._LoraConfigParam, dumped)

View file

@ -0,0 +1,585 @@
"""TrainingClient for Tinker API."""
from __future__ import annotations
import asyncio
import logging
import threading
import time
from contextlib import asynccontextmanager
from functools import cache
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Tuple, cast
import torch
from tinker import types
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
from tinker.lib.public_interfaces.api_future import APIFuture, AwaitableConcurrentFuture
from tinker.lib.telemetry import Telemetry, capture_exceptions
from tinker.lib.telemetry_provider import TelemetryProvider
from tinker.types import training_optim_step_params
from ..api_future_impl import (
QueueState,
QueueStateObserver,
_APIFuture,
_CombinedAPIFuture,
)
from ..chunked_fwdbwd_helpers import combine_fwd_bwd_output_results
from ..retry_handler import RetryConfig
from ..sync_only import sync_only
if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer
from ..internal_client_holder import InternalClientHolder
from .sampling_client import SamplingClient
# pyright: reportPrivateImportUsage=false
logger = logging.getLogger(__name__)
# FwdBwdChunkSize: forward/backward payloads are split into chunks of at most
# MAX_CHUNK_LEN data and roughly MAX_CHUNK_NUMBER_COUNT numbers each.
MAX_CHUNK_LEN = 128
MAX_CHUNK_NUMBER_COUNT = 500000
# User-facing error for TrainingClient operations attempted before a model_id is set.
# (Fixed typo: "initiliazing" -> "initializing".)
MODEL_ID_NOT_SET_ERROR = "model_id must be set before calling forward. Try initializing the TrainingClient with a model_id by either calling create_lora_training_client on the ServiceClient, or initializing the TrainingClient with an existing model_id."
# Signature for user-supplied custom loss callables:
# (data, per-datum logprob tensors) -> (scalar loss, metrics dict).
CustomLossFnV1 = Callable[[List[types.Datum], List[Any]], Tuple[Any, Dict[str, float]]]
class TrainingClient(TelemetryProvider, QueueStateObserver):
"""Client for training ML models with forward/backward passes and optimization.
The TrainingClient corresponds to a fine-tuned model that you can train and sample from.
You typically get one by calling `service_client.create_lora_training_client()`.
Key methods:
- forward_backward() - compute gradients for training
- optim_step() - update model parameters with Adam optimizer
- save_weights_and_get_sampling_client() - export trained model for inference
Args:
holder: Internal client managing HTTP connections and async operations
model_id: Unique identifier for the model to train. Required for training operations.
Example:
>>> training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen2.5-7B")
>>> fwdbwd_future = training_client.forward_backward(training_data, "cross_entropy")
>>> optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
>>> fwdbwd_result = fwdbwd_future.result() # Wait for gradients
>>> optim_result = optim_future.result() # Wait for parameter update
>>> sampling_client = training_client.save_weights_and_get_sampling_client("my-model")
"""
def __init__(self, holder: InternalClientHolder, model_id: types.ModelID | None = None):
self.holder = holder
self.model_id = model_id
self._training_client_id: int = self.holder.get_training_client_id()
self._request_id_lock: threading.Lock = threading.Lock()
self._request_id_counter: int = 0
self._turn_counter: int = 0
self._turn_waiters: dict[int, asyncio.Event] = {}
self._last_queue_state_logged: float = 0
# Reserves a request id for a request. Requests are to be executed in the order of request ids.
def _get_request_id(self) -> int:
with self._request_id_lock:
request_id = self._request_id_counter
self._request_id_counter += 1
return request_id
# Waits for the turn for a given request id to be executed.
# This has to be used via a with statement so that the turn is released
# only after current request was successfully dispatched.
@asynccontextmanager
async def _take_turn(self, request_id: int):
assert self._turn_counter <= request_id, "Same request id cannot be taken twice"
if self._turn_counter < request_id:
try:
event = asyncio.Event()
self._turn_waiters[request_id] = event
await event.wait()
finally:
del self._turn_waiters[request_id]
assert self._turn_counter == request_id
try:
yield
finally:
self._turn_counter += 1
if self._turn_counter in self._turn_waiters:
self._turn_waiters[self._turn_counter].set()
def _make_idempotency_key(self, request_id: int) -> str:
return self.holder.make_training_client_idempotency_key(
self._training_client_id, request_id
)
def _guaranteed_model_id(self) -> types.ModelID:
assert self.model_id is not None, MODEL_ID_NOT_SET_ERROR
return self.model_id
def _estimate_number_count(self, datum: types.Datum) -> int:
return datum.model_input.length + sum(
len(value.data) for _, value in datum.loss_fn_inputs.items()
)
def _chunked_requests_generator(
self, data: List[types.Datum]
) -> Generator[List[types.Datum], None, None]:
current_chunk: List[types.Datum] = []
current_chunk_number_count = 0
for datum in data:
estimated_number_count = self._estimate_number_count(datum)
if (
len(current_chunk) > 0
and current_chunk_number_count + estimated_number_count > MAX_CHUNK_NUMBER_COUNT
) or (len(current_chunk) == MAX_CHUNK_LEN):
yield current_chunk
current_chunk = []
current_chunk_number_count = 0
current_chunk.append(datum)
current_chunk_number_count += estimated_number_count
if len(current_chunk) > 0:
yield current_chunk
def _chunked_requests(self, data: List[types.Datum]) -> List[tuple[int, List[types.Datum]]]:
return [(self._get_request_id(), chunk) for chunk in self._chunked_requests_generator(data)]
async def _send_single_forward_request(
self, request_id: int, data: List[types.Datum], loss_fn: types.LossFnType
):
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
return await client.training.forward(
model_id=self._guaranteed_model_id(),
forward_input=_to_fwdbwd_input_params(
types.ForwardBackwardInput(data=data, loss_fn=loss_fn)
),
idempotency_key=self._make_idempotency_key(request_id),
)
@capture_exceptions(fatal=True)
def forward(
self, data: List[types.Datum], loss_fn: types.LossFnType
) -> APIFuture[types.ForwardBackwardOutput]:
requests = self._chunked_requests(data)
@capture_exceptions(fatal=True)
async def _forward_async():
start_time = time.time()
futures = []
for request_id, data in requests:
async with self._take_turn(request_id):
untyped_future = await self.holder.execute_with_retries(
self._send_single_forward_request, request_id, data, loss_fn
)
api_future = _APIFuture(
types.ForwardBackwardOutput,
self.holder,
untyped_future,
request_start_time=start_time,
request_type="Forward",
queue_state_observer=self,
)
futures.append(api_future)
return await _CombinedAPIFuture(futures, combine_fwd_bwd_output_results, self.holder)
return self.holder.run_coroutine_threadsafe(_forward_async())
async def forward_async(
self, data: List[types.Datum], loss_fn: types.LossFnType
) -> APIFuture[types.ForwardBackwardOutput]:
return self.forward(data, loss_fn)
async def _send_single_forward_backward_request(
self, request_id: int, data: List[types.Datum], loss_fn: types.LossFnType
):
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
return await client.training.forward_backward(
model_id=self._guaranteed_model_id(),
forward_backward_input=_to_fwdbwd_input_params(
types.ForwardBackwardInput(data=data, loss_fn=loss_fn)
),
idempotency_key=self._make_idempotency_key(request_id),
)
@capture_exceptions(fatal=True)
def forward_backward(
self, data: List[types.Datum], loss_fn: types.LossFnType
) -> APIFuture[types.ForwardBackwardOutput]:
requests = self._chunked_requests(data)
@capture_exceptions(fatal=True)
async def _forward_backward_async():
futures = []
start_time = time.time()
for request_id, data in requests:
async with self._take_turn(request_id):
untyped_future = await self.holder.execute_with_retries(
self._send_single_forward_backward_request, request_id, data, loss_fn
)
api_future = _APIFuture(
types.ForwardBackwardOutput,
self.holder,
untyped_future,
request_start_time=start_time,
request_type="ForwardBackward",
queue_state_observer=self,
)
futures.append(api_future)
return await _CombinedAPIFuture(futures, combine_fwd_bwd_output_results, self.holder)
return self.holder.run_coroutine_threadsafe(_forward_backward_async())
async def forward_backward_async(
self, data: List[types.Datum], loss_fn: types.LossFnType
) -> APIFuture[types.ForwardBackwardOutput]:
return self.forward_backward(data, loss_fn)
@sync_only
@capture_exceptions(fatal=True)
def forward_backward_custom(
self, data: List[types.Datum], loss_fn: CustomLossFnV1
) -> APIFuture[types.ForwardBackwardOutput]:
"""Synchronous version of forward_backward_custom_async."""
return self.holder.run_coroutine_threadsafe(
self.forward_backward_custom_async(data, loss_fn)
).result()
@capture_exceptions(fatal=True)
async def forward_backward_custom_async(
self, data: List[types.Datum], loss_fn: CustomLossFnV1
) -> APIFuture[types.ForwardBackwardOutput]:
import torch
# First do a forward pass and get logprobs
forward_future = await self.forward_async(data, "cross_entropy")
forward_result = await forward_future.result_async()
logprobs_list: List[torch.Tensor] = []
for out in forward_result.loss_fn_outputs:
logprob = torch.tensor(out["logprobs"].data).clone().detach().requires_grad_(True)
logprobs_list.append(logprob)
# Now apply user-provided function
loss, metrics = loss_fn(data, logprobs_list)
loss.backward()
grads = []
for logprob in logprobs_list:
if logprob.grad is None:
raise ValueError("No gradient computed for logprob tensor")
grads.append(logprob.grad)
linear_loss_data = []
for datum, grad in zip(data, grads):
loss_fn_inputs: Any = {
"target_tokens": datum.loss_fn_inputs["target_tokens"],
"weights": -grad, # Pass PyTorch tensor directly (will be converted to TensorData)
}
linear_loss_data.append(
types.Datum(
model_input=datum.model_input,
loss_fn_inputs=loss_fn_inputs,
)
)
# Do the backward pass with the gradients
backward_future = await self.forward_backward_async(linear_loss_data, "cross_entropy")
# We need to slightly modify the future to add the custom metrics, so we use _CombinedAPIFuture
# to transform the future.
def add_custom_metrics(
results: List[types.ForwardBackwardOutput],
) -> types.ForwardBackwardOutput:
result = results[0] # Single result
result.metrics.update(metrics)
return result
return _CombinedAPIFuture([backward_future], add_custom_metrics, self.holder)
@capture_exceptions(fatal=True)
def optim_step(self, adam_params: types.AdamParams) -> APIFuture[types.OptimStepResponse]:
request_id = self._get_request_id()
@capture_exceptions(fatal=True)
async def _optim_step_async():
start_time = time.time()
async def _send_request():
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
return await client.training.optim_step(
model_id=self._guaranteed_model_id(),
adam_params=_to_adam_params(adam_params),
idempotency_key=self._make_idempotency_key(request_id),
)
async with self._take_turn(request_id):
untyped_future = await self.holder.execute_with_retries(_send_request)
return await _APIFuture(
types.OptimStepResponse,
self.holder,
untyped_future,
request_start_time=start_time,
request_type="OptimStep",
queue_state_observer=self,
)
return self.holder.run_coroutine_threadsafe(_optim_step_async())
async def optim_step_async(
self, adam_params: types.AdamParams
) -> APIFuture[types.OptimStepResponse]:
return self.optim_step(adam_params)
@capture_exceptions(fatal=True)
def save_state(self, name: str) -> APIFuture[types.SaveWeightsResponse]:
request_id = self._get_request_id()
@capture_exceptions(fatal=True)
async def _save_state_async():
start_time = time.time()
async def _send_request():
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
return await client.weights.save(
model_id=self._guaranteed_model_id(),
path=name,
idempotency_key=self._make_idempotency_key(request_id),
)
async with self._take_turn(request_id):
future = await self.holder.execute_with_retries(_send_request)
return await _APIFuture(
types.SaveWeightsResponse,
self.holder,
future,
request_start_time=start_time,
request_type="SaveWeights",
queue_state_observer=self,
)
return self.holder.run_coroutine_threadsafe(_save_state_async())
async def save_state_async(self, name: str) -> APIFuture[types.SaveWeightsResponse]:
return self.save_state(name)
@capture_exceptions(fatal=True)
def load_state(self, path: str) -> APIFuture[types.LoadWeightsResponse]:
request_id = self._get_request_id()
@capture_exceptions(fatal=True)
async def _load_state_async():
start_time = time.time()
async def _send_request():
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
return await client.weights.load(
model_id=self._guaranteed_model_id(),
path=path,
idempotency_key=self._make_idempotency_key(request_id),
)
async with self._take_turn(request_id):
future = await self.holder.execute_with_retries(_send_request)
return await _APIFuture(
types.LoadWeightsResponse,
self.holder,
future,
request_start_time=start_time,
request_type="LoadWeights",
queue_state_observer=self,
)
return self.holder.run_coroutine_threadsafe(_load_state_async())
async def load_state_async(self, path: str) -> APIFuture[types.LoadWeightsResponse]:
return self.load_state(path)
@capture_exceptions(fatal=True)
def save_weights_for_sampler(self, name: str) -> APIFuture[types.SaveWeightsForSamplerResponse]:
request_id = self._get_request_id()
@capture_exceptions(fatal=True)
async def _save_weights_for_sampler_async():
start_time = time.time()
async def _send_request():
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
return await client.weights.save_for_sampler(
model_id=self._guaranteed_model_id(),
path=name,
idempotency_key=self._make_idempotency_key(request_id),
)
async with self._take_turn(request_id):
future = await self.holder.execute_with_retries(_send_request)
return await _APIFuture(
types.SaveWeightsForSamplerResponse,
self.holder,
future,
request_start_time=start_time,
request_type="SaveWeightsForSampler",
queue_state_observer=self,
)
return self.holder.run_coroutine_threadsafe(_save_weights_for_sampler_async())
async def save_weights_for_sampler_async(
self, name: str
) -> APIFuture[types.SaveWeightsForSamplerResponse]:
return self.save_weights_for_sampler(name)
@capture_exceptions(fatal=True)
def unload_model(
self,
) -> APIFuture[types.UnloadModelResponse]:
request_id = self._get_request_id()
@capture_exceptions(fatal=True)
async def _unload_model_async():
start_time = time.time()
async def _send_request():
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
return await client.models.unload(
model_id=self._guaranteed_model_id(),
idempotency_key=self._make_idempotency_key(request_id),
)
async with self._take_turn(request_id):
future = await self.holder.execute_with_retries(_send_request)
return await _APIFuture(
types.UnloadModelResponse,
self.holder,
future,
request_start_time=start_time,
request_type="UnloadModel",
queue_state_observer=self,
)
return self.holder.run_coroutine_threadsafe(_unload_model_async())
async def unload_model_async(self) -> APIFuture[types.UnloadModelResponse]:
return self.unload_model()
def _get_info_submit(self) -> AwaitableConcurrentFuture[types.GetInfoResponse]:
request_id = self._get_request_id()
async def _get_info_async():
async def _send_request():
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
return await client.models.get_info(
model_id=self._guaranteed_model_id(),
idempotency_key=self._make_idempotency_key(request_id),
)
async with self._take_turn(request_id):
return await self.holder.execute_with_retries(_send_request)
return self.holder.run_coroutine_threadsafe(_get_info_async())
@sync_only
@capture_exceptions(fatal=True)
def get_info(self) -> types.GetInfoResponse:
return self._get_info_submit().result()
@capture_exceptions(fatal=True)
async def get_info_async(self) -> types.GetInfoResponse:
return await self._get_info_submit()
@cache
@capture_exceptions(fatal=True)
def get_tokenizer(self) -> PreTrainedTokenizer:
return _get_tokenizer(self._guaranteed_model_id(), self.holder)
@capture_exceptions(fatal=True)
def create_sampling_client(
self, model_path: str, retry_config: RetryConfig | None = None
) -> SamplingClient:
from .sampling_client import SamplingClient
return SamplingClient(self.holder, model_path=model_path, retry_config=retry_config)
@capture_exceptions(fatal=True)
def save_weights_and_get_sampling_client(
self, name: str, retry_config: RetryConfig | None = None
) -> SamplingClient:
from .sampling_client import SamplingClient
path = self.save_weights_for_sampler(name).result().path
return SamplingClient(self.holder, model_path=path, retry_config=retry_config)
@capture_exceptions(fatal=True)
async def save_weights_and_get_sampling_client_async(
    self, name: str, retry_config: RetryConfig | None = None
) -> SamplingClient:
    """Async variant: save weights under ``name`` and return a sampler for them."""
    from .sampling_client import SamplingClient

    # Two awaits: one to submit the save, one for the save to finish.
    future = await self.save_weights_for_sampler_async(name)
    saved = await future.result_async()
    return SamplingClient(self.holder, model_path=saved.path, retry_config=retry_config)
def get_telemetry(self) -> Telemetry | None:
    """Expose the holder's telemetry sink (TelemetryProvider protocol)."""
    telemetry = self.holder.get_telemetry()
    return telemetry
def on_queue_state_change(self, queue_state: QueueState) -> None:
    """Warn (at most once a minute) when training is paused by the server."""
    log_interval_s = 60
    if queue_state == QueueState.ACTIVE:
        return
    if time.time() - self._last_queue_state_logged < log_interval_s:
        # Rate-limit: we already logged recently.
        return
    self._last_queue_state_logged = time.time()
    reason_by_state = {
        QueueState.PAUSED_RATE_LIMIT: "concurrent models rate limit hit",
        QueueState.PAUSED_CAPACITY: "out of capacity",
    }
    reason = reason_by_state.get(queue_state, "unknown")
    logger.warning(f"Training is paused for {self.model_id}. Reason: {reason}")
def _to_fwdbwd_input_params(x: types.ForwardBackwardInput) -> types._ForwardBackwardInputParam:
    """Convert a typed forward/backward input into its wire-format param dict."""
    dumped = x.model_dump()
    return cast(types._ForwardBackwardInputParam, dumped)
def _to_adam_params(x: types.AdamParams) -> training_optim_step_params.AdamParams:
    """Convert typed Adam parameters into their wire-format param dict."""
    dumped = x.model_dump()
    return cast(training_optim_step_params.AdamParams, dumped)
def _resolve_tokenizer_id(model_name: str) -> str:
    """Map a tinker model name to the HuggingFace repo the tokenizer lives in."""
    if model_name.startswith("meta-llama/Llama-3"):
        # Avoid gating of Llama 3 models:
        return "baseten/Meta-Llama-3-tokenizer"
    # We generally adhere to the huggingface convention of "<org>/<model>" but
    # in some cases we'll deploy variants using the format
    # "<org>/<model>/<variant>". In that case, we want to load the tokenizer
    # using the huggingface convention.
    if model_name.count("/") == 2:
        org, model, _variant = model_name.split("/", 2)
        return f"{org}/{model}"
    return model_name


def _get_tokenizer(model_id: types.ModelID, holder: InternalClientHolder) -> PreTrainedTokenizer:
    """Resolve ``model_id`` to its model name via get_info and load its tokenizer.

    Blocks on a get_info round trip run on the holder's background loop.
    """
    from transformers.models.auto.tokenization_auto import AutoTokenizer

    async def _get_info_async():
        with holder.aclient(ClientConnectionPoolType.TRAIN) as client:
            return await client.models.get_info(model_id=model_id)

    info = holder.run_coroutine_threadsafe(_get_info_async()).result()
    model_name = info.model_data.model_name
    assert model_name is not None, "This shouldn't happen: model_name is None"
    tokenizer_id = _resolve_tokenizer_id(model_name)
    # Bug fix: the AutoTokenizer flag is `use_fast`; the previous `fast=True`
    # kwarg is not a from_pretrained parameter and was silently forwarded to
    # the tokenizer's kwargs with no effect.
    return AutoTokenizer.from_pretrained(tokenizer_id, use_fast=True)

View file

@ -0,0 +1,278 @@
"""
Generalizable retry handler for API requests with connection limiting,
progress tracking, and exponential backoff.
"""
from __future__ import annotations
import asyncio
import logging
import random
import time
import traceback
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Generic, Type, TypeVar
import httpx
import tinker
from tinker.lib.telemetry import Telemetry
from .._constants import (
DEFAULT_CONNECTION_LIMITS,
INITIAL_RETRY_DELAY,
MAX_RETRY_DELAY,
)
from .retryable_exception import RetryableException
logger = logging.getLogger(__name__)
T = TypeVar("T")
def is_retryable_status_code(status_code: int) -> bool:
    """Return True for HTTP statuses worth retrying: 408, 409, 429, and any 5xx."""
    if status_code in {408, 409, 429}:
        return True
    return 500 <= status_code < 600
@dataclass
class RetryConfig:
    """Tunables for RetryHandler: concurrency cap, progress timeout, backoff shape."""

    # Cap on concurrent in-flight requests; falls back to 100 when the shared
    # connection limits leave max_connections unset.
    max_connections: int = DEFAULT_CONNECTION_LIMITS.max_connections or 100
    progress_timeout: float = 30 * 60  # Very long straggler
    retry_delay_base: float = INITIAL_RETRY_DELAY
    retry_delay_max: float = MAX_RETRY_DELAY
    # Fraction of the delay used as symmetric random jitter (+/- 25% by default).
    jitter_factor: float = 0.25
    # When False, RetryHandler calls the function once with no retry machinery.
    enable_retry_logic: bool = True
    # Exception types that always trigger a retry (in addition to retryable
    # APIStatusError status codes — see RetryHandler._should_retry).
    retryable_exceptions: tuple[Type[Exception], ...] = (
        asyncio.TimeoutError,
        tinker.APIConnectionError,
        httpx.TimeoutException,
        RetryableException,
    )

    def __post_init__(self):
        # Fail fast on a config that could never acquire the semaphore.
        if self.max_connections <= 0:
            raise ValueError(f"max_connections must be positive, got {self.max_connections}")

    def __hash__(self):
        # dataclass(eq=True) sets __hash__ to None, so define it explicitly.
        # Keep this tuple in sync with the field list above.
        return hash(
            (
                self.max_connections,
                self.progress_timeout,
                self.retry_delay_base,
                self.retry_delay_max,
                self.jitter_factor,
                self.enable_retry_logic,
                self.retryable_exceptions,
            )
        )
class RetryHandler(Generic[T]):  # noqa: UP046
    """
    A generalizable retry handler for API requests.

    Features:
    - Connection limiting with semaphores
    - Global progress timeout tracking
    - Exponential backoff with jitter
    - Configurable error classification

    Usage:
        handler = RetryHandler(config=retry_config)
        result = await handler.execute(my_function, *args, **kwargs)
    """

    def __init__(
        self,
        config: RetryConfig = RetryConfig(),
        name: str = "default",
        telemetry: Telemetry | None = None,
    ):
        # NOTE(review): the RetryConfig() default instance is shared across
        # all RetryHandlers that don't pass a config — fine as long as configs
        # are treated as immutable; confirm nothing mutates it.
        self.config = config
        self.name = name  # used only as a log prefix
        self._telemetry = telemetry
        current_time = time.time()
        # Timestamp of the last time *any* request made progress; the
        # progress watchdog in execute() measures against this.
        self._last_global_progress = current_time
        self._last_printed_progress = current_time
        self._processed_count = 0
        self._waiting_at_semaphore_count = 0
        self._in_retry_loop_count = 0
        self._retry_count = 0
        self._exception_counts = {}  # Track exception types and their counts
        # (NOTE(review): _exception_counts is never updated in this class —
        # possibly dead state.)
        self._errors_since_last_retry: defaultdict[str, int] = defaultdict(int)
        # The semaphore is used to limit the number of concurrent requests.
        # Without a semaphore, progress can grind to a halt as requests fight
        # for limited httpx connections.
        self._semaphore = asyncio.Semaphore(config.max_connections)

    async def execute(self, func: Callable[..., Awaitable[T]], *args: Any, **kwargs: Any) -> T:
        """Use as a direct function call.

        Runs ``func`` under the connection semaphore with retries, and cancels
        it (converting to APIConnectionError) if no request in this handler
        has made progress for config.progress_timeout seconds.
        """
        self._waiting_at_semaphore_count += 1
        async with self._semaphore:
            self._waiting_at_semaphore_count -= 1
            if self._in_retry_loop_count == 0:
                # First active request after idle: reset the progress clock so
                # idle time doesn't count against the timeout.
                self._last_global_progress = time.time()
            self._in_retry_loop_count += 1
            self._maybe_log_progress()

            async def _check_progress(parent_task: asyncio.Task[T]):
                # Watchdog: cancel the parent task once the global progress
                # deadline passes, marking it so the except branch below can
                # tell this cancellation apart from an external one.
                while True:
                    deadline = self._last_global_progress + self.config.progress_timeout
                    if time.time() > deadline:
                        parent_task._no_progress_made_marker = True  # pyright: ignore[reportAttributeAccessIssue]
                        parent_task.cancel()
                    await asyncio.sleep(deadline - time.time())

            current_task = asyncio.current_task()
            assert current_task is not None
            current_task._no_progress_made_marker = False  # pyright: ignore[reportAttributeAccessIssue]
            progress_task = asyncio.create_task(_check_progress(current_task))
            try:
                result = await self._execute_with_retry(func, *args, **kwargs)
                self._last_global_progress = time.time()
                return result
            except asyncio.CancelledError:
                if current_task._no_progress_made_marker:  # pyright: ignore[reportAttributeAccessIssue]
                    # Cancellation came from our own watchdog: un-cancel and
                    # surface a retryable connection error instead.
                    current_task.uncancel()
                    # Create a dummy request for the exception (required by APIConnectionError)
                    dummy_request = httpx.Request("GET", "http://localhost")
                    raise tinker.APIConnectionError(
                        message=f"No progress made in {self.config.progress_timeout}s. Requests appear to be stuck.",
                        request=dummy_request,
                    )
                raise
            finally:
                self._in_retry_loop_count -= 1
                self._maybe_log_progress()
                progress_task.cancel()

    def _maybe_log_progress(self):
        """Emit throttled (>2s apart, or when idle) debug stats about the handler."""
        current_time = time.time()
        elapsed_since_last_printed_progress = current_time - self._last_printed_progress
        finished = self._waiting_at_semaphore_count + self._in_retry_loop_count == 0
        if elapsed_since_last_printed_progress > 2 or finished:
            logger.debug(
                f"[{self.name}]: {self._waiting_at_semaphore_count} waiting, {self._in_retry_loop_count} in progress, {self._processed_count} completed"
            )
            if self._errors_since_last_retry:
                sorted_items = sorted(
                    self._errors_since_last_retry.items(), key=lambda x: x[1], reverse=True
                )
                logger.debug(
                    f"[{self.name}]: {self._retry_count} total retries, errors since last log: {sorted_items}"
                )
            self._last_printed_progress = current_time
            self._errors_since_last_retry.clear()

    async def _execute_with_retry(
        self, func: Callable[..., Awaitable[T]], *args: Any, **kwargs: Any
    ) -> T:
        """Main retry logic."""
        # Fast path: skip all retry logic if disabled
        if not self.config.enable_retry_logic:
            return await func(*args, **kwargs)
        start_time = time.time()
        attempt_count = 0
        while True:
            current_time = time.time()
            self._maybe_log_progress()
            try:
                attempt_count += 1
                logger.debug(f"Attempting request (attempt #{attempt_count})")
                result = await func(*args, **kwargs)
            except Exception as e:
                exception_str = f"{type(e).__name__}: {str(e) or 'No error message'}"
                self._errors_since_last_retry[exception_str] += 1
                should_retry = self._should_retry(e)
                if telemetry := self.get_telemetry():
                    current_time = time.time()
                    telemetry.log(
                        "RetryHandler.execute.exception",
                        event_data={
                            "func": getattr(
                                func, "__qualname__", getattr(func, "__name__", type(func).__name__)
                            ),
                            "exception": str(e),
                            "exception_type": type(e).__name__,
                            "exception_stack": "".join(
                                traceback.format_exception(type(e), e, e.__traceback__)
                            )
                            if e.__traceback__
                            else None,
                            "status_code": getattr(e, "status_code", None),
                            "should_retry": should_retry,
                            "attempt_count": attempt_count,
                            "start_time": start_time,
                            "current_time": current_time,
                            "elapsed_time": current_time - start_time,
                        },
                        severity="WARNING" if should_retry else "ERROR",
                    )
                if not should_retry:
                    logger.error(f"Request failed with non-retryable error: {exception_str}")
                    raise
                self._log_retry_reason(e, attempt_count)
                self._retry_count += 1
                # Calculate retry delay with exponential backoff and jitter
                retry_delay = self._calculate_retry_delay(attempt_count - 1)
                logger.debug(f"Retrying in {retry_delay:.2f}s")
                await asyncio.sleep(retry_delay)
            else:
                logger.debug(f"Request succeeded after {attempt_count} attempts")
                self._processed_count += 1
                return result

    def _should_retry(self, exception: Exception) -> bool:
        """Determine if an exception should trigger a retry."""
        # Check if it's a generally retryable exception type
        if isinstance(exception, self.config.retryable_exceptions):
            return True
        # Check for API status errors with retryable status codes
        if isinstance(exception, tinker.APIStatusError):
            return is_retryable_status_code(exception.status_code)
        return False

    def _log_retry_reason(self, exception: Exception, attempt_count: int):
        """Log the reason for retrying."""
        if isinstance(exception, asyncio.TimeoutError):
            logger.debug("Request timed out")
        elif isinstance(exception, tinker.APIConnectionError):
            logger.debug(f"Request failed with connection error: {exception}")
        elif isinstance(exception, tinker.APIStatusError):
            logger.debug(
                f"Request attempt #{attempt_count} failed with status {exception.status_code}"
            )
        else:
            logger.debug(f"Request attempt #{attempt_count} failed with error: {exception}")

    def _calculate_retry_delay(self, attempt: int) -> float:
        """Calculate retry delay with exponential backoff and jitter."""
        delay = self.config.retry_delay_max
        try:
            delay = min(self.config.retry_delay_base * (2**attempt), self.config.retry_delay_max)
        except OverflowError:
            # There are two possible overflow errors:
            # (1) `min` tries to convert the value to a float, which can overflow
            #     if the integer value gets too large
            # (2) If the attempt number is too large, the `2 ** attempt` will overflow
            delay = self.config.retry_delay_max
        jitter = delay * self.config.jitter_factor * (2 * random.random() - 1)
        # Ensure the final delay doesn't exceed the maximum, even with jitter
        return max(0, min(delay + jitter, self.config.retry_delay_max))

    def get_telemetry(self) -> Telemetry | None:
        """Telemetry sink for retry events, or None when telemetry is disabled."""
        return self._telemetry

View file

@ -0,0 +1,3 @@
class RetryableException(Exception):
    """Marker exception: raising this tells RetryHandler the operation is safe to retry.

    The previous ``__init__(self, message)`` only forwarded to ``super().__init__``;
    ``Exception`` already accepts and stores the message, so the boilerplate
    was removed. Callers that pass a single message string are unaffected.
    """

View file

@ -0,0 +1,79 @@
"""
Decorator to prevent synchronous methods from being called in async contexts.
This helps users avoid a common footgun where calling sync methods from async code
can cause deadlocks and performance issues.
"""
import asyncio
import logging
import traceback
from functools import wraps
from typing import Any, Callable, TypeVar
logger = logging.getLogger(__name__)
T = TypeVar("T")
def is_jupyter() -> bool:
    """Return True when running under a Jupyter notebook or qtconsole kernel."""
    try:
        # get_ipython is only injected into the namespace by IPython;
        # plain Python raises NameError here.
        shell_cls = get_ipython().__class__.__name__  # type: ignore
    except NameError:
        return False
    return shell_cls in ("ZMQInteractiveShell", "Shell")
def is_in_async_context() -> bool:
    """Return True when called from code running inside an asyncio event loop."""
    try:
        asyncio.get_running_loop()
        return True
    except RuntimeError:
        # get_running_loop raises RuntimeError when no loop is running.
        return False
def make_error_message(  # noqa: UP047
    func: Callable[..., T], args: tuple[Any, ...], kwargs: dict[str, Any]
) -> str:
    """Build the warning shown when a sync method is invoked from async code."""
    method_name = func.__name__
    async_method_name = f"{method_name}_async"
    # Prefix with the receiver's class when this looks like a bound-method call.
    class_name = f"{args[0].__class__.__name__}." if args and hasattr(args[0], "__class__") else ""
    return (
        f"Synchronous method '{class_name}{method_name}()' called from async context. "
        f"Use '{class_name}{async_method_name}()' instead.\n"
        f"Calling sync methods from async code can cause deadlocks and performance issues."
    )
def sync_only(func: Callable[..., T]) -> Callable[..., T]:  # noqa: UP047
    """Decorator to ensure a method is only called from sync context.

    Prevents the common footgun of calling sync methods from async code,
    which can cause deadlocks and performance issues. The error message
    assumes a corresponding ``{method_name}_async`` exists.
    """

    @wraps(func)
    def _wrapped(*args: Any, **kwargs: Any) -> T:
        # Warn — but still proceed — when a sync entry point is hit from a
        # running event loop. Jupyter always has a loop, so it is exempt.
        misused = is_in_async_context() and not is_jupyter()
        if misused:
            logger.warning(make_error_message(func, args, kwargs))
            logger.warning(
                f"===== Stack for calling sync from async ===== \n{traceback.format_stack()}\n ==========="
            )
        return func(*args, **kwargs)

    return _wrapped

383
src/tinker/lib/telemetry.py Normal file
View file

@ -0,0 +1,383 @@
import asyncio
import contextlib
import functools
import inspect
import logging
import os
import platform
import threading
import traceback
from collections import deque
from collections.abc import Awaitable
from datetime import datetime, timezone
from typing import (
Callable,
ParamSpec,
TypeVar,
cast,
overload,
)
from uuid import uuid4
from tinker._exceptions import APIError
from tinker._version import __version__
from tinker.types.generic_event import GenericEvent
from tinker.types.session_end_event import SessionEndEvent
from tinker.types.session_start_event import SessionStartEvent
from tinker.types.severity import Severity
from tinker.types.telemetry_batch import TelemetryBatch
from tinker.types.telemetry_event import TelemetryEvent
from tinker.types.telemetry_response import TelemetryResponse
from tinker.types.telemetry_send_params import TelemetrySendParams
from tinker.types.unhandled_exception_event import UnhandledExceptionEvent
from .async_tinker_provider import AsyncTinkerProvider
from .client_connection_pool_type import ClientConnectionPoolType
from .sync_only import sync_only
from .telemetry_provider import TelemetryProvider
logger = logging.getLogger(__name__)

MAX_BATCH_SIZE: int = 100  # max events per telemetry HTTP request
FLUSH_INTERVAL: float = 10.0  # seconds between periodic background flushes
FLUSH_TIMEOUT: float = 30.0  # max seconds to wait for the queue to drain
MAX_QUEUE_SIZE: int = 10000  # events beyond this cap are dropped
HTTP_TIMEOUT_SECONDS: float = 5.0  # per-request timeout for telemetry sends
class Telemetry:
    """Background telemetry pipeline.

    Events are appended to a thread-safe queue and flushed in batches by a
    task running on the provider's event loop. Push/flush counters let
    callers wait until everything queued so far has been sent.
    """

    def __init__(self, tinker_provider: AsyncTinkerProvider, session_id: str):
        self._tinker_provider: AsyncTinkerProvider = tinker_provider
        self._session_id: str = session_id
        self._session_start: datetime = datetime.now(timezone.utc)
        # Monotonic per-session event counter (guarded by its own lock).
        self._session_index: int = 0
        self._session_index_lock: threading.Lock = threading.Lock()
        self._queue: deque[TelemetryEvent] = deque()
        self._queue_lock: threading.Lock = threading.Lock()
        # Created on the event loop thread by _start's callback, hence Optional.
        self._task: asyncio.Task[None] | None = None
        self._flush_event: asyncio.Event | None = None
        # _push_counter/_flush_counter track events queued vs. events whose
        # send was attempted; _wait_until_drained compares the two.
        self._push_counter: int = 0
        self._flush_counter: int = 0
        self._counter_lock: threading.Lock = threading.Lock()
        _ = self._log(self._session_start_event())
        self._start()

    def _start(self):
        # The Event and Task must be created on the loop thread itself.
        def cb():
            self._flush_event = asyncio.Event()
            self._task = asyncio.create_task(self._periodic_flush(), name="tinker-telemetry")

        _ = self._tinker_provider.get_loop().call_soon_threadsafe(cb)

    def stop(self):
        """Cancel the background flush task (thread-safe)."""
        def cb():
            if task := self._task:
                _ = task.cancel()

        _ = self._tinker_provider.get_loop().call_soon_threadsafe(cb)

    async def _periodic_flush(self):
        # Flush every FLUSH_INTERVAL seconds, or sooner when _trigger_flush
        # sets the event.
        while True:
            if self._flush_event:
                try:
                    # NOTE(review): catching builtin TimeoutError here only
                    # matches asyncio.wait_for's timeout on Python 3.11+,
                    # where asyncio.TimeoutError became an alias — confirm
                    # the supported interpreter floor.
                    _ = await asyncio.wait_for(self._flush_event.wait(), timeout=FLUSH_INTERVAL)
                except TimeoutError:
                    pass
                finally:
                    self._flush_event.clear()
            await self._flush()

    async def _flush(self):
        # Drain the queue in MAX_BATCH_SIZE chunks until it is empty.
        while True:
            with self._queue_lock:
                if not self._queue:
                    break
                batch_size = min(MAX_BATCH_SIZE, len(self._queue))
                events = [self._queue.popleft() for _ in range(batch_size)]
            batch = self._batch(events)
            try:
                _ = await self._send_batch_with_retry(batch)
            finally:
                # increment counter even if we fail to send the batch so we're not blocking
                # on a flush for non-APIErrors (e.g. missing API key)
                with self._counter_lock:
                    self._flush_counter += len(events)

    def _trigger_flush(self):
        # Thread-safe nudge: wake _periodic_flush early.
        if self._flush_event:
            _ = self._tinker_provider.get_loop().call_soon_threadsafe(self._flush_event.set)

    async def _wait_until_drained(self) -> bool:
        """Poll until everything queued so far has been flushed; False on timeout."""
        with self._counter_lock:
            target_count = self._push_counter
        start = asyncio.get_event_loop().time()
        while asyncio.get_event_loop().time() - start < FLUSH_TIMEOUT:
            with self._counter_lock:
                if self._flush_counter >= target_count:
                    return True
            await asyncio.sleep(0.1)
        return False

    def _wait_until_drained_sync(self) -> bool:
        # Blocking variant; must only be called from a thread that is NOT
        # running the provider's event loop.
        try:
            return asyncio.run_coroutine_threadsafe(
                self._wait_until_drained(), self._tinker_provider.get_loop()
            ).result(timeout=FLUSH_TIMEOUT)
        except (TimeoutError, asyncio.CancelledError):
            return False

    async def _send_batch_with_retry(self, batch: TelemetryBatch) -> TelemetryResponse:
        # Retries API errors forever with a flat 1s delay; other exception
        # types propagate to _flush.
        while True:
            try:
                return await self._send_batch(batch)
            except APIError as e:
                logger.warning("Failed to send telemetry batch", exc_info=e)
                await asyncio.sleep(1)
                continue

    async def _send_batch(self, batch: TelemetryBatch) -> TelemetryResponse:
        with self._tinker_provider.aclient(ClientConnectionPoolType.TELEMETRY) as client:
            params = _to_send_params(batch)
            return await client.telemetry.send(**params, timeout=HTTP_TIMEOUT_SECONDS)

    def _log(self, *events: TelemetryEvent) -> bool:
        """Queue events atomically; returns False (dropping all) when full."""
        with self._queue_lock:
            if len(self._queue) + len(events) > MAX_QUEUE_SIZE:
                logger.warning("Telemetry queue full, dropping events")
                return False
            self._queue.extend(events)
        with self._counter_lock:
            self._push_counter += len(events)
        return True

    def log(
        self,
        event_name: str,  # should be low cardinality
        event_data: dict[str, object] | None = None,
        severity: Severity = "INFO",
    ) -> bool:
        """Queue a generic named event; returns False if the queue is full."""
        return self._log(self._generic_event(event_name, event_data, severity))

    async def log_exception(self, exception: BaseException, severity: Severity = "ERROR") -> bool:
        logged = self._log(self._exception_event(exception, severity))
        # trigger flush but don't block on it
        self._trigger_flush()
        return logged

    async def log_fatal_exception(
        self, exception: BaseException, severity: Severity = "ERROR"
    ) -> bool:
        """Queue the exception plus a session-end event and wait for the flush."""
        logged = self._log(self._exception_event(exception, severity), self._session_end_event())
        self._trigger_flush()
        # wait for the flush to complete
        _ = await self._wait_until_drained()
        if logged:
            self._notify_exception_logged()
        return logged

    @sync_only
    def log_exception_sync(self, exception: BaseException, severity: Severity = "ERROR") -> bool:
        logged = self._log(self._exception_event(exception, severity))
        # trigger flush but don't block on it
        self._trigger_flush()
        return logged

    @sync_only
    def log_fatal_exception_sync(
        self, exception: BaseException, severity: Severity = "ERROR"
    ) -> bool:
        logged = self._log(self._exception_event(exception, severity), self._session_end_event())
        self._trigger_flush()
        # wait for the flush to complete
        # (only when no loop is running on this thread — blocking inside the
        # loop thread would deadlock the flush task)
        if _current_loop() is None:
            _ = self._wait_until_drained_sync()
        if logged:
            self._notify_exception_logged()
        return logged

    def _notify_exception_logged(self):
        logger.info(f"Exception logged for session ID: {self._session_id}")

    def _batch(self, events: list[TelemetryEvent]) -> TelemetryBatch:
        return TelemetryBatch(
            platform=platform.system(),
            sdk_version=__version__,
            session_id=self._session_id,
            events=events,
        )

    def _generic_event(
        self,
        event_name: str,
        event_data: dict[str, object] | None = None,
        severity: Severity = "INFO",
    ) -> GenericEvent:
        return GenericEvent(
            event="GENERIC_EVENT",
            event_id=str(uuid4()),
            event_session_index=self._next_session_index(),
            timestamp=datetime.now(timezone.utc),
            severity=severity,
            event_name=event_name,
            event_data=event_data or {},
        )

    def _session_start_event(self) -> SessionStartEvent:
        return SessionStartEvent(
            event="SESSION_START",
            event_id=str(uuid4()),
            event_session_index=self._next_session_index(),
            timestamp=self._session_start,
            severity="INFO",
        )

    def _session_end_event(self) -> SessionEndEvent:
        end_time = datetime.now(timezone.utc)
        return SessionEndEvent(
            event="SESSION_END",
            event_id=str(uuid4()),
            event_session_index=self._next_session_index(),
            timestamp=end_time,
            severity="INFO",
            duration=str(end_time - self._session_start),
        )

    def _exception_event(
        self, exception: BaseException, severity: Severity
    ) -> UnhandledExceptionEvent:
        return UnhandledExceptionEvent(
            event="UNHANDLED_EXCEPTION",
            event_id=str(uuid4()),
            event_session_index=self._next_session_index(),
            timestamp=datetime.now(timezone.utc),
            severity=severity,
            error_type=exception.__class__.__name__,
            error_message=str(exception),
            traceback="".join(
                traceback.format_exception(type(exception), exception, exception.__traceback__)
            )
            if exception.__traceback__
            else None,
        )

    def _next_session_index(self) -> int:
        # Thread-safe post-increment.
        with self._session_index_lock:
            idx = self._session_index
            self._session_index += 1
            return idx

    @contextlib.contextmanager
    def capture_exceptions(self, fatal: bool = False, severity: Severity = "ERROR"):
        """Sync context manager: log any escaping exception, then re-raise."""
        try:
            yield
        except Exception as e:
            if fatal:
                _ = self.log_fatal_exception_sync(e, severity)
            else:
                _ = self.log_exception_sync(e, severity)
            raise

    @contextlib.asynccontextmanager
    async def acapture_exceptions(self, fatal: bool = False, severity: Severity = "ERROR"):
        """Async context manager: log any escaping exception, then re-raise."""
        try:
            yield
        except Exception as e:
            if fatal:
                _ = await self.log_fatal_exception(e, severity)
            else:
                _ = await self.log_exception(e, severity)
            raise
def _is_telemetry_enabled() -> bool:
    """Telemetry is on unless TINKER_TELEMETRY is set to something falsy."""
    flag = os.environ.get("TINKER_TELEMETRY", "1")
    return flag.lower() in {"1", "true", "yes", "on"}
def init_telemetry(tinker_provider: AsyncTinkerProvider, session_id: str) -> Telemetry | None:
    """Create a Telemetry instance, or None when disabled or construction fails."""
    if not _is_telemetry_enabled():
        return None
    try:
        return Telemetry(tinker_provider, session_id)
    except Exception as e:
        # Telemetry must never take the SDK down; degrade to no-op.
        logger.warning(f"Error initializing telemetry: {e}")
        return None
P = ParamSpec("P")  # parameters of the wrapped callable
R = TypeVar("R")  # return type of the wrapped callable

# Decorator to capture exceptions. Class must implement TelemetryProvider.
# Pass fatal=True to log a session end event in addition to the exception.
#
# Example:
# @capture_exceptions
# def my_method(self):
#     pass
#
# @capture_exceptions(fatal=True, severity="CRITICAL")
# def my_method(self):
#     pass
@overload
def capture_exceptions(
    func: Callable[P, R], *, fatal: bool = False, severity: Severity = "ERROR"
) -> Callable[P, R]: ...
@overload
def capture_exceptions(
    *, fatal: bool = False, severity: Severity = "ERROR"
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
def capture_exceptions(
    func: Callable[P, R] | None = None, *, fatal: bool = False, severity: Severity = "ERROR"
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
    """Decorator (bare or parameterized) that reports escaping exceptions.

    Works on sync and async callables. The Telemetry sink is looked up from
    the first positional argument (`self` implementing TelemetryProvider), or
    — for nested functions — from a captured `self` in the closure.
    """
    def _get_telemetry(func: Callable[..., object], args: tuple[object, ...]) -> Telemetry | None:
        # Bound-method case: args[0] is the receiver.
        if args and isinstance(args[0], TelemetryProvider):
            return args[0].get_telemetry()
        # Closure case: a nested function that closed over `self`.
        with contextlib.suppress(TypeError, AttributeError):
            self = inspect.getclosurevars(func).nonlocals.get("self")
            if isinstance(self, TelemetryProvider):
                return self.get_telemetry()
        logger.warning("@capture_exceptions used without TelemetryProvider: %s", func.__name__)
        return None

    def _decorate(func: Callable[P, R]) -> Callable[P, R]:
        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def _awrapper(*args: P.args, **kwargs: P.kwargs) -> R:
                telemetry = _get_telemetry(func, args)
                if telemetry is None:
                    # No sink available: run unwrapped.
                    return await cast(Callable[..., Awaitable[R]], func)(*args, **kwargs)
                async with telemetry.acapture_exceptions(fatal=fatal, severity=severity):
                    return await cast(Callable[..., Awaitable[R]], func)(*args, **kwargs)

            return cast(Callable[P, R], _awrapper)

        @functools.wraps(func)
        def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            telemetry = _get_telemetry(func, args)
            if telemetry is None:
                return func(*args, **kwargs)
            with telemetry.capture_exceptions(fatal=fatal, severity=severity):
                return func(*args, **kwargs)

        return cast(Callable[P, R], _wrapper)

    # Bare use (@capture_exceptions) passes func; parameterized use returns
    # the decorator itself.
    return _decorate if func is None else _decorate(func)
def _to_send_params(batch: TelemetryBatch) -> TelemetrySendParams:
    """Serialize a batch into the kwargs shape telemetry.send expects."""
    dumped = cast(object, batch.model_dump())
    return cast(TelemetrySendParams, dumped)
def _current_loop() -> asyncio.AbstractEventLoop | None:
    """Return the running event loop, or None when called from plain sync code."""
    loop: asyncio.AbstractEventLoop | None
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None
    return loop

View file

@ -0,0 +1,11 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol, runtime_checkable
if TYPE_CHECKING:
from .telemetry import Telemetry
@runtime_checkable
class TelemetryProvider(Protocol):
    """Structural interface for objects that can hand out a Telemetry sink.

    Marked runtime_checkable so decorators can isinstance-check candidates.
    """

    def get_telemetry(self) -> Telemetry | None: ...

View file

@ -0,0 +1,642 @@
import asyncio
import contextlib
import os
import platform
import threading
from collections.abc import Coroutine
from concurrent.futures import Future as ConcurrentFuture
from typing import Any, TypeVar, cast
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
from tinker.lib.telemetry import (
MAX_BATCH_SIZE,
MAX_QUEUE_SIZE,
Telemetry,
_is_telemetry_enabled,
capture_exceptions,
init_telemetry,
)
from tinker.lib.telemetry_provider import TelemetryProvider
from tinker.types.generic_event import GenericEvent
from tinker.types.session_end_event import SessionEndEvent
from tinker.types.session_start_event import SessionStartEvent
from tinker.types.telemetry_batch import TelemetryBatch
from tinker.types.telemetry_event import TelemetryEvent
from tinker.types.telemetry_response import TelemetryResponse
from tinker.types.unhandled_exception_event import UnhandledExceptionEvent
# pyright: reportMissingParameterType=false
# pyright: reportOptionalMemberAccess=false
T = TypeVar("T")
class MockEventLoopProvider:
    """Minimal loop-provider stand-in: a Mock loop plus an eager coroutine runner."""

    def __init__(self):
        loop = Mock()
        loop.call_soon_threadsafe = Mock()
        self.loop = loop

    def get_loop(self):
        return self.loop

    def run_coroutine_threadsafe(self, coro: Coroutine[Any, Any, T]):
        # Run the coroutine to completion right now instead of scheduling it,
        # and hand back a future-shaped Mock holding the result.
        future = Mock()
        future.result = Mock(return_value=asyncio.run(coro))
        return future
class MockAsyncTinkerProvider:
    """Fake AsyncTinkerProvider: queues call_soon_threadsafe callbacks for manual
    draining and stubs the telemetry client with an accepting AsyncMock."""

    def __init__(self):
        self.callbacks = []
        self.loop = Mock()
        self.loop.call_soon_threadsafe = Mock(side_effect=lambda cb: self.callbacks.append(cb))
        self._client = MagicMock()
        self.telemetry_send_mock = AsyncMock(return_value=TelemetryResponse(status="accepted"))
        self._client.telemetry.send = self.telemetry_send_mock
        self._client_cm = MagicMock()
        self._client_cm.__enter__ = Mock(return_value=self._client)
        self._client_cm.__exit__ = Mock(return_value=None)

    def execute_callbacks(self):
        """Run every callback queued via call_soon_threadsafe, then clear them."""
        for cb in self.callbacks:
            cb()
        self.callbacks.clear()

    def get_loop(self):
        return self.loop

    def run_coroutine_threadsafe(self, coro: Coroutine[Any, Any, T]):
        # Schedule the coroutine on the current policy's loop and bridge its
        # outcome into a concurrent future.
        fut: ConcurrentFuture[Any] = ConcurrentFuture()

        async def _runner():
            try:
                result = await coro
                fut.set_result(result)
            except Exception as e:
                fut.set_exception(e)

        _ = asyncio.get_event_loop_policy().get_event_loop().create_task(_runner())
        return AwaitableConcurrentFuture(fut)

    def aclient(self, client_pool_type: ClientConnectionPoolType):
        _ = client_pool_type
        return self._client_cm
class TestTelemetryClass:
def setup_method(self):
self.tinker_provider = MockAsyncTinkerProvider()
self.telemetry = Telemetry(self.tinker_provider, session_id="test-session-id")
def teardown_method(self):
if hasattr(self, "telemetry"):
self.telemetry.stop()
@pytest.mark.asyncio
async def test_initialization(self):
self.tinker_provider.execute_callbacks()
assert self.telemetry._session_index == 1
assert len(self.telemetry._queue) == 1
assert isinstance(self.telemetry._queue[0], SessionStartEvent)
assert isinstance(self.telemetry._flush_event, asyncio.Event)
def test_log_single_event(self):
event = self.telemetry._session_end_event()
result = self.telemetry._log(event)
assert result is True
assert len(self.telemetry._queue) == 2
assert self.telemetry._queue[-1] == event
def test_log_multiple_events(self):
event1 = self.telemetry._session_end_event()
event2 = self.telemetry._exception_event(ValueError("test"), "ERROR")
result = self.telemetry._log(event1, event2)
assert result is True
assert len(self.telemetry._queue) == 3
assert self.telemetry._queue[-2] == event1
assert self.telemetry._queue[-1] == event2
def test_log_generic_event_default(self):
idx_before = self.telemetry._session_index
result = self.telemetry.log("test-event")
assert result is True
assert len(self.telemetry._queue) == 2
event = self.telemetry._queue[-1]
assert isinstance(event, GenericEvent)
assert event.event == "GENERIC_EVENT"
assert event.event_name == "test-event"
assert event.severity == "INFO"
assert event.event_data == {}
assert event.event_session_index == idx_before
assert isinstance(event.event_id, str) and event.event_id
def test_log_generic_event_custom(self):
idx_before = self.telemetry._session_index
payload: dict[str, object] = {"a": 1, "b": "x"}
result = self.telemetry.log("custom-event", event_data=payload, severity="WARNING")
assert result is True
assert len(self.telemetry._queue) == 2
event = self.telemetry._queue[-1]
assert isinstance(event, GenericEvent)
assert event.event == "GENERIC_EVENT"
assert event.event_name == "custom-event"
assert event.severity == "WARNING"
assert event.event_data == payload
assert event.event_session_index == idx_before
def test_log_queue_full(self):
initial_size = len(self.telemetry._queue)
events_to_add = MAX_QUEUE_SIZE - initial_size - 1
for _ in range(events_to_add):
self.telemetry._queue.append(self.telemetry._session_end_event())
assert len(self.telemetry._queue) == MAX_QUEUE_SIZE - 1
event1 = self.telemetry._session_end_event()
event2 = self.telemetry._session_end_event()
with patch("tinker.lib.telemetry.logger") as mock_logger:
result = self.telemetry._log(event1, event2)
assert result is False
assert len(self.telemetry._queue) == MAX_QUEUE_SIZE - 1
mock_logger.warning.assert_called_once_with("Telemetry queue full, dropping events")
def test_batch_creation(self):
events = [
self.telemetry._session_start_event(),
self.telemetry._session_end_event(),
]
batch = self.telemetry._batch(cast(list[TelemetryEvent], events))
assert isinstance(batch, TelemetryBatch)
assert batch.platform == platform.system()
assert batch.session_id == str(self.telemetry._session_id)
assert batch.events == events
assert batch.sdk_version is not None
def test_log_exception_sync(self):
try:
raise RuntimeError("Test exception")
except RuntimeError as e:
with patch.object(self.telemetry, "_trigger_flush") as mock_trigger:
result = self.telemetry.log_exception_sync(e, "ERROR")
assert result is True
assert len(self.telemetry._queue) == 2
mock_trigger.assert_called_once()
exception_event = self.telemetry._queue[-1]
assert isinstance(exception_event, UnhandledExceptionEvent)
assert exception_event.error_type == "RuntimeError"
assert exception_event.error_message == "Test exception"
@pytest.mark.asyncio
async def test_log_exception_async(self):
try:
raise RuntimeError("Test exception")
except RuntimeError as e:
with patch.object(self.telemetry, "_trigger_flush") as mock_trigger:
result = await self.telemetry.log_exception(e, "ERROR")
assert result is True
assert len(self.telemetry._queue) == 2
mock_trigger.assert_called_once()
exception_event = self.telemetry._queue[-1]
assert isinstance(exception_event, UnhandledExceptionEvent)
assert exception_event.error_type == "RuntimeError"
assert exception_event.error_message == "Test exception"
def test_log_fatal_exception_sync(self):
try:
raise RuntimeError("Fatal error")
except RuntimeError as e:
with patch.object(self.telemetry, "_trigger_flush") as mock_trigger:
with patch.object(
self.telemetry, "_wait_until_drained_sync", return_value=True
) as mock_wait:
result = self.telemetry.log_fatal_exception_sync(e, "CRITICAL")
assert result is True
assert len(self.telemetry._queue) == 3
mock_trigger.assert_called_once()
mock_wait.assert_called_once()
exception_event = self.telemetry._queue[-2]
assert isinstance(exception_event, UnhandledExceptionEvent)
assert exception_event.severity == "CRITICAL"
end_event = self.telemetry._queue[-1]
assert isinstance(end_event, SessionEndEvent)
@pytest.mark.asyncio
async def test_log_fatal_exception_async(self):
try:
raise RuntimeError("Fatal error")
except RuntimeError as e:
with patch.object(self.telemetry, "_trigger_flush") as mock_trigger:
with patch.object(
self.telemetry, "_wait_until_drained", new_callable=AsyncMock, return_value=True
) as mock_wait:
result = await self.telemetry.log_fatal_exception(e, "CRITICAL")
assert result is True
assert len(self.telemetry._queue) == 3
mock_trigger.assert_called_once()
mock_wait.assert_called_once()
exception_event = self.telemetry._queue[-2]
assert isinstance(exception_event, UnhandledExceptionEvent)
assert exception_event.severity == "CRITICAL"
end_event = self.telemetry._queue[-1]
assert isinstance(end_event, SessionEndEvent)
def test_capture_exceptions_context_manager(self):
with patch.object(self.telemetry, "_trigger_flush"):
with pytest.raises(ValueError):
with self.telemetry.capture_exceptions():
raise ValueError("Test error")
assert len(self.telemetry._queue) == 2
exception_event = self.telemetry._queue[-1]
assert isinstance(exception_event, UnhandledExceptionEvent)
assert exception_event.error_type == "ValueError"
def test_capture_exceptions_context_manager_fatal(self):
with patch.object(self.telemetry, "_trigger_flush"):
with patch.object(self.telemetry, "_wait_until_drained_sync", return_value=True):
with pytest.raises(ValueError):
with self.telemetry.capture_exceptions(fatal=True, severity="CRITICAL"):
raise ValueError("Fatal error")
assert len(self.telemetry._queue) == 3
exception_event = self.telemetry._queue[-2]
assert isinstance(exception_event, UnhandledExceptionEvent)
assert exception_event.severity == "CRITICAL"
end_event = self.telemetry._queue[-1]
assert isinstance(end_event, SessionEndEvent)
@pytest.mark.asyncio
async def test_acapture_exceptions_context_manager(self):
with patch.object(self.telemetry, "_trigger_flush"):
with pytest.raises(ValueError):
async with self.telemetry.acapture_exceptions():
raise ValueError("Async test error")
assert len(self.telemetry._queue) == 2
exception_event = self.telemetry._queue[-1]
assert isinstance(exception_event, UnhandledExceptionEvent)
assert exception_event.error_type == "ValueError"
@pytest.mark.asyncio
async def test_acapture_exceptions_context_manager_fatal(self):
with patch.object(self.telemetry, "_trigger_flush"):
with patch.object(
self.telemetry, "_wait_until_drained", new_callable=AsyncMock, return_value=True
):
with pytest.raises(ValueError):
async with self.telemetry.acapture_exceptions(fatal=True):
raise ValueError("Async fatal error")
assert len(self.telemetry._queue) == 3
exception_event = self.telemetry._queue[-2]
assert isinstance(exception_event, UnhandledExceptionEvent)
end_event = self.telemetry._queue[-1]
assert isinstance(end_event, SessionEndEvent)
class TestTelemetryEnvironment:
    """Environment-variable gating of telemetry via TINKER_TELEMETRY."""

    @pytest.mark.parametrize(
        "env_value,expected",
        [
            # Truthy spellings are accepted case-insensitively.
            ("1", True),
            ("true", True),
            ("True", True),
            ("TRUE", True),
            ("yes", True),
            ("Yes", True),
            ("on", True),
            ("ON", True),
            # Any other value disables telemetry.
            ("0", False),
            ("false", False),
            ("no", False),
            ("off", False),
            ("", False),
            ("random", False),
        ],
    )
    def test_is_telemetry_enabled(self, env_value, expected):
        with patch.dict(os.environ, {"TINKER_TELEMETRY": env_value}):
            assert _is_telemetry_enabled() == expected

    def test_is_telemetry_enabled_not_set(self):
        """When the variable is absent entirely, telemetry defaults to enabled."""
        with patch.dict(os.environ, {}, clear=True):
            assert _is_telemetry_enabled() is True

    def test_init_telemetry_enabled(self):
        with patch.dict(os.environ, {"TINKER_TELEMETRY": "1"}):
            provider = MockAsyncTinkerProvider()
            instance = init_telemetry(provider, session_id="test-session-id")
            assert instance is not None
            assert isinstance(instance, Telemetry)
            instance.stop()

    def test_init_telemetry_disabled(self):
        with patch.dict(os.environ, {"TINKER_TELEMETRY": "0"}):
            provider = MockAsyncTinkerProvider()
            assert init_telemetry(provider, session_id="test-session-id") is None

    def test_init_telemetry_with_exception(self):
        """Construction failures are swallowed: init returns None and warns once."""
        with patch.dict(os.environ, {"TINKER_TELEMETRY": "1"}):
            provider = Mock()
            provider.get_loop.side_effect = Exception("Init error")
            with patch("tinker.lib.telemetry.logger") as logger_mock:
                assert init_telemetry(provider, session_id="test-session-id") is None
                logger_mock.warning.assert_called_once()
                assert "Error initializing telemetry" in str(logger_mock.warning.call_args)
class TestCaptureExceptionsDecorator:
    """Tests for the `capture_exceptions` decorator in sync and async forms.

    The decorator is expected to locate a TelemetryProvider (via the wrapped
    function's first argument, or — for inner functions — via a closed-over
    `self`) and route the call through telemetry.(a)capture_exceptions.
    """

    class MockTelemetryProvider:
        # Minimal stand-in satisfying the provider protocol: exposes a Mock
        # telemetry object whose (a)capture_exceptions calls can be inspected.
        def __init__(self):
            self.telemetry = Mock()
            self.telemetry.capture_exceptions = Mock()
            self.telemetry.acapture_exceptions = Mock()

        def get_telemetry(self):
            return self.telemetry

    def test_decorator_on_sync_function(self):
        """Bare @capture_exceptions wraps a sync function with default kwargs."""
        provider = self.MockTelemetryProvider()

        @capture_exceptions
        def test_func(self):
            return "success"

        # Give the mocked context manager a usable __enter__/__exit__ pair.
        provider.telemetry.capture_exceptions.return_value.__enter__ = Mock()
        provider.telemetry.capture_exceptions.return_value.__exit__ = Mock(return_value=False)
        result = test_func(provider)
        assert result == "success"
        provider.telemetry.capture_exceptions.assert_called_once_with(fatal=False, severity="ERROR")

    def test_decorator_on_sync_function_with_exception(self):
        """Decorator kwargs are forwarded; the original exception still propagates."""
        provider = self.MockTelemetryProvider()

        @capture_exceptions(fatal=True, severity="CRITICAL")
        def test_func(self):
            raise ValueError("Test error")

        provider.telemetry.capture_exceptions.return_value.__enter__ = Mock()
        provider.telemetry.capture_exceptions.return_value.__exit__ = Mock(return_value=False)
        with pytest.raises(ValueError, match="Test error"):
            test_func(provider)
        provider.telemetry.capture_exceptions.assert_called_once_with(
            fatal=True, severity="CRITICAL"
        )

    @pytest.mark.asyncio
    async def test_decorator_on_async_function(self):
        """Async functions are routed through acapture_exceptions instead."""
        provider = self.MockTelemetryProvider()

        @capture_exceptions
        async def test_func(self):
            return "async success"

        # Async context-manager protocol must be stubbed explicitly.
        async_cm = AsyncMock()
        async_cm.__aenter__ = AsyncMock(return_value=None)
        async_cm.__aexit__ = AsyncMock(return_value=False)
        provider.telemetry.acapture_exceptions.return_value = async_cm
        result = await test_func(provider)
        assert result == "async success"
        provider.telemetry.acapture_exceptions.assert_called_once_with(
            fatal=False, severity="ERROR"
        )

    @pytest.mark.asyncio
    async def test_decorator_on_async_function_with_exception(self):
        """Custom severity reaches acapture_exceptions; the error re-raises."""
        provider = self.MockTelemetryProvider()

        @capture_exceptions(severity="WARNING")
        async def test_func(self):
            raise RuntimeError("Async error")

        async_cm = AsyncMock()
        async_cm.__aenter__ = AsyncMock(return_value=None)
        async_cm.__aexit__ = AsyncMock(return_value=False)
        provider.telemetry.acapture_exceptions.return_value = async_cm
        with pytest.raises(RuntimeError, match="Async error"):
            await test_func(provider)
        provider.telemetry.acapture_exceptions.assert_called_once_with(
            fatal=False, severity="WARNING"
        )

    def test_decorator_without_telemetry_provider(self):
        """A function with no provider argument runs undisturbed."""
        @capture_exceptions
        def test_func():
            return "no provider"

        result = test_func()
        assert result == "no provider"

    def test_decorator_with_non_provider_self(self):
        """A `self` that lacks get_telemetry is ignored rather than erroring."""
        class NonProvider:
            @capture_exceptions
            def test_method(self):
                return "not a provider"

        obj = NonProvider()
        result = obj.test_method()
        assert result == "not a provider"

    def test_decorator_as_plain_decorator(self):
        """Usable without parentheses."""
        @capture_exceptions
        def test_func():
            return "plain decorator"

        result = test_func()
        assert result == "plain decorator"

    def test_decorator_with_parentheses(self):
        """Usable with empty parentheses, too."""
        @capture_exceptions()
        def test_func():
            return "decorator with parens"

        result = test_func()
        assert result == "decorator with parens"

    def test_decorator_on_inner_sync_function_closing_over_self(self):
        """An inner function finds the provider through its closure over `self`."""
        provider = self.MockTelemetryProvider()

        class Wrapper:
            def __init__(self, p: Any):
                self.p: Any = p

            @capture_exceptions
            def outer(self) -> str:
                @capture_exceptions
                def inner() -> str:
                    # reference `self` so it is captured in the closure
                    return "ok" if self else "bad"

                return inner()

            def get_telemetry(self) -> Telemetry | None:
                return self.p.get_telemetry()

        wrapper = Wrapper(provider)
        provider.telemetry.capture_exceptions.return_value.__enter__ = Mock()
        provider.telemetry.capture_exceptions.return_value.__exit__ = Mock(return_value=False)
        result = wrapper.outer()
        assert result == "ok"
        # Called twice: once for outer, once for inner via closure lookup
        assert provider.telemetry.capture_exceptions.call_count == 2

    @pytest.mark.asyncio
    async def test_decorator_on_inner_async_function_closing_over_self(self):
        """Async twin of the closure test: both outer and inner calls are captured."""
        provider = self.MockTelemetryProvider()

        class Wrapper:
            def __init__(self, p: Any):
                self.p: Any = p

            @capture_exceptions
            async def outer(self) -> str:
                @capture_exceptions
                async def inner() -> str:
                    # reference `self` so it is captured in the closure
                    return "ok-async" if self else "bad"

                return await inner()

            def get_telemetry(self) -> Telemetry | None:
                return self.p.get_telemetry()

        wrapper = Wrapper(provider)
        async_cm = AsyncMock()
        async_cm.__aenter__ = AsyncMock(return_value=None)
        async_cm.__aexit__ = AsyncMock(return_value=False)
        provider.telemetry.acapture_exceptions.return_value = async_cm
        result = await wrapper.outer()
        assert result == "ok-async"
        # Called twice: once for outer, once for inner via closure lookup
        assert provider.telemetry.acapture_exceptions.call_count == 2
class TestTelemetryFlush:
    """Batching and flush behaviour of the telemetry queue."""

    def setup_method(self):
        self.tinker_provider = MockAsyncTinkerProvider()
        self.telemetry = Telemetry(self.tinker_provider, session_id="test-session-id")

    def teardown_method(self):
        # setup_method may have raised before assigning; guard the attribute.
        if hasattr(self, "telemetry"):
            self.telemetry.stop()

    def _enqueue(self, count):
        # Push `count` throwaway events onto the telemetry queue.
        for _ in range(count):
            _ = self.telemetry._log(self.telemetry._session_end_event())

    def test_flush_empty_queue(self):
        """Nothing queued means no batch is sent."""
        self.telemetry._queue.clear()
        with patch.object(
            self.telemetry, "_send_batch_with_retry", new_callable=AsyncMock
        ) as send_mock:
            asyncio.run(self.telemetry._flush())
            send_mock.assert_not_called()

    def test_flush_small_batch(self):
        """A handful of events drains in a single batch."""
        self._enqueue(5)
        with patch.object(
            self.telemetry, "_send_batch_with_retry", new_callable=AsyncMock
        ) as send_mock:
            asyncio.run(self.telemetry._flush())
            assert len(self.telemetry._queue) == 0
            assert send_mock.call_count == 1

    def test_flush_large_batch(self):
        """Overflowing MAX_BATCH_SIZE forces a second send."""
        self._enqueue(MAX_BATCH_SIZE + 10)
        with patch.object(
            self.telemetry, "_send_batch_with_retry", new_callable=AsyncMock
        ) as send_mock:
            asyncio.run(self.telemetry._flush())
            assert len(self.telemetry._queue) == 0
            assert send_mock.call_count == 2

    def test_counters_and_wait(self):
        """Push counter advances per _log; flush counter catches up after _flush."""
        start = self.telemetry._push_counter
        self._enqueue(3)
        assert self.telemetry._push_counter == start + 3
        with patch.object(
            self.telemetry,
            "_send_batch_with_retry",
            new_callable=AsyncMock,
            return_value=TelemetryResponse(status="accepted"),
        ):
            asyncio.run(self.telemetry._flush())
        assert self.telemetry._flush_counter >= start + 3

    @pytest.mark.asyncio
    async def test_log_exception_sync_from_event_loop_protection(self):
        """log_exception_sync still succeeds when a running loop is detected."""
        with patch("tinker.lib.telemetry._current_loop") as loop_mock:
            loop_mock.return_value = Mock()
            try:
                raise ValueError("Test error")
            except ValueError as exc:
                assert self.telemetry.log_exception_sync(exc) is True
class TestTelemetryProviderProtocol:
    def test_telemetry_provider_protocol(self):
        """TelemetryProvider is a runtime-checkable structural protocol."""

        class WithHook:
            def get_telemetry(self):
                return None

        class WithoutHook:
            pass

        assert isinstance(WithHook(), TelemetryProvider)
        assert not isinstance(WithoutHook(), TelemetryProvider)
class TestSyncContextManager:
    @pytest.mark.asyncio
    async def test_send_batch_uses_sync_context_manager(self):
        """_send_batch enters and exits the provider's client CM exactly once."""
        provider = MockAsyncTinkerProvider()
        instance = Telemetry(provider, session_id="test-session-id")
        start_events = [instance._session_start_event()]
        batch = instance._batch(cast(list[TelemetryEvent], start_events))
        response = await instance._send_batch(batch)
        provider._client_cm.__enter__.assert_called_once()
        provider._client_cm.__exit__.assert_called_once()
        provider.telemetry_send_mock.assert_called_once()
        assert response.status == "accepted"
        instance.stop()
class TestCrossLoopSafety:
    """Telemetry must tolerate being poked from threads/loops other than its own."""

    @pytest.mark.asyncio
    async def test_trigger_flush_from_different_loop(self):
        """_trigger_flush called from a foreign thread still sets the flush event."""
        provider = MockAsyncTinkerProvider()
        instance = Telemetry(provider, session_id="test-session-id")
        provider.execute_callbacks()

        worker = threading.Thread(target=instance._trigger_flush)
        worker.start()
        worker.join()

        provider.execute_callbacks()
        assert instance._flush_event.is_set()
        instance.stop()

    @pytest.mark.asyncio
    async def test_periodic_flush_with_asyncio_event(self):
        """A pre-set flush event makes the periodic loop invoke _flush."""
        provider = MockAsyncTinkerProvider()
        instance = Telemetry(provider, session_id="test-session-id")
        provider.execute_callbacks()
        instance._log(instance._session_end_event())
        with patch.object(instance, "_flush", new_callable=AsyncMock) as flush_mock:
            instance._flush_event.set()
            flush_task = asyncio.create_task(instance._periodic_flush())
            await asyncio.sleep(0.1)
            flush_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await flush_task
            flush_mock.assert_called()
        instance.stop()

0
src/tinker/py.typed Normal file
View file

View file

@ -0,0 +1,103 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from .models import (
ModelsResource,
AsyncModelsResource,
ModelsResourceWithRawResponse,
AsyncModelsResourceWithRawResponse,
ModelsResourceWithStreamingResponse,
AsyncModelsResourceWithStreamingResponse,
)
from .futures import (
FuturesResource,
AsyncFuturesResource,
FuturesResourceWithRawResponse,
AsyncFuturesResourceWithRawResponse,
FuturesResourceWithStreamingResponse,
AsyncFuturesResourceWithStreamingResponse,
)
from .service import (
ServiceResource,
AsyncServiceResource,
ServiceResourceWithRawResponse,
AsyncServiceResourceWithRawResponse,
ServiceResourceWithStreamingResponse,
AsyncServiceResourceWithStreamingResponse,
)
from .weights import (
WeightsResource,
AsyncWeightsResource,
WeightsResourceWithRawResponse,
AsyncWeightsResourceWithRawResponse,
WeightsResourceWithStreamingResponse,
AsyncWeightsResourceWithStreamingResponse,
)
from .sampling import (
SamplingResource,
AsyncSamplingResource,
SamplingResourceWithRawResponse,
AsyncSamplingResourceWithRawResponse,
SamplingResourceWithStreamingResponse,
AsyncSamplingResourceWithStreamingResponse,
)
from .training import (
TrainingResource,
AsyncTrainingResource,
TrainingResourceWithRawResponse,
AsyncTrainingResourceWithRawResponse,
TrainingResourceWithStreamingResponse,
AsyncTrainingResourceWithStreamingResponse,
)
from .telemetry import (
TelemetryResource,
AsyncTelemetryResource,
TelemetryResourceWithRawResponse,
AsyncTelemetryResourceWithRawResponse,
TelemetryResourceWithStreamingResponse,
AsyncTelemetryResourceWithStreamingResponse,
)
# Public re-export surface of the resources package; mirrors the imports above.
# Entries are grouped per resource: sync/async classes plus their raw-response
# and streaming-response wrappers.
__all__ = [
    "ServiceResource",
    "AsyncServiceResource",
    "ServiceResourceWithRawResponse",
    "AsyncServiceResourceWithRawResponse",
    "ServiceResourceWithStreamingResponse",
    "AsyncServiceResourceWithStreamingResponse",
    "TrainingResource",
    "AsyncTrainingResource",
    "TrainingResourceWithRawResponse",
    "AsyncTrainingResourceWithRawResponse",
    "TrainingResourceWithStreamingResponse",
    "AsyncTrainingResourceWithStreamingResponse",
    "ModelsResource",
    "AsyncModelsResource",
    "ModelsResourceWithRawResponse",
    "AsyncModelsResourceWithRawResponse",
    "ModelsResourceWithStreamingResponse",
    "AsyncModelsResourceWithStreamingResponse",
    "WeightsResource",
    "AsyncWeightsResource",
    "WeightsResourceWithRawResponse",
    "AsyncWeightsResourceWithRawResponse",
    "WeightsResourceWithStreamingResponse",
    "AsyncWeightsResourceWithStreamingResponse",
    "SamplingResource",
    "AsyncSamplingResource",
    "SamplingResourceWithRawResponse",
    "AsyncSamplingResourceWithRawResponse",
    "SamplingResourceWithStreamingResponse",
    "AsyncSamplingResourceWithStreamingResponse",
    "FuturesResource",
    "AsyncFuturesResource",
    "FuturesResourceWithRawResponse",
    "AsyncFuturesResourceWithRawResponse",
    "FuturesResourceWithStreamingResponse",
    "AsyncFuturesResourceWithStreamingResponse",
    "TelemetryResource",
    "AsyncTelemetryResource",
    "TelemetryResourceWithRawResponse",
    "AsyncTelemetryResourceWithRawResponse",
    "TelemetryResourceWithStreamingResponse",
    "AsyncTelemetryResourceWithStreamingResponse",
]

View file

@ -0,0 +1,210 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import Any, cast
import httpx
from ..types import ModelID, RequestID, future_retrieve_params
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from .._utils import maybe_transform, async_maybe_transform
from .._compat import cached_property
from .._resource import SyncAPIResource, AsyncAPIResource
from .._response import (
to_raw_response_wrapper,
to_streamed_response_wrapper,
async_to_raw_response_wrapper,
async_to_streamed_response_wrapper,
)
from .._base_client import make_request_options
from ..types.model_id import ModelID
from ..types.request_id import RequestID
from ..types.future_retrieve_response import FutureRetrieveResponse
__all__ = ["FuturesResource", "AsyncFuturesResource"]
class FuturesResource(SyncAPIResource):
    @cached_property
    def with_raw_response(self) -> FuturesResourceWithRawResponse:
        """Prefix any method call with this to get the raw response object
        instead of the parsed content.

        See https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
        """
        return FuturesResourceWithRawResponse(self)

    @cached_property
    def with_streaming_response(self) -> FuturesResourceWithStreamingResponse:
        """Like `.with_raw_response`, but the response body is not read eagerly.

        See https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
        """
        return FuturesResourceWithStreamingResponse(self)

    def retrieve(
        self,
        *,
        request_id: RequestID,
        model_id: ModelID | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
    ) -> FutureRetrieveResponse:
        """
        Retrieves the result of a future by its ID

        Args:
          extra_headers: Send extra headers

          extra_query: Add additional query parameters to the request

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds

          idempotency_key: Specify a custom idempotency key for this request
        """
        payload = maybe_transform(
            {
                "request_id": request_id,
                "model_id": model_id,
            },
            future_retrieve_params.FutureRetrieveParams,
        )
        opts = make_request_options(
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
            idempotency_key=idempotency_key,
        )
        response = self._post(
            "/api/v1/retrieve_future",
            body=payload,
            options=opts,
            # Union types cannot be passed in as arguments in the type system
            cast_to=cast(Any, FutureRetrieveResponse),
        )
        return cast(FutureRetrieveResponse, response)
class AsyncFuturesResource(AsyncAPIResource):
    @cached_property
    def with_raw_response(self) -> AsyncFuturesResourceWithRawResponse:
        """Prefix any method call with this to get the raw response object
        instead of the parsed content.

        See https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
        """
        return AsyncFuturesResourceWithRawResponse(self)

    @cached_property
    def with_streaming_response(self) -> AsyncFuturesResourceWithStreamingResponse:
        """Like `.with_raw_response`, but the response body is not read eagerly.

        See https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
        """
        return AsyncFuturesResourceWithStreamingResponse(self)

    async def retrieve(
        self,
        *,
        request_id: RequestID,
        model_id: ModelID | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
        max_retries: int | NotGiven = NOT_GIVEN,
    ) -> FutureRetrieveResponse:
        """
        Retrieves the result of a future by its ID

        Args:
          extra_headers: Send extra headers

          extra_query: Add additional query parameters to the request

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds

          idempotency_key: Specify a custom idempotency key for this request
        """
        opts = make_request_options(
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
            idempotency_key=idempotency_key,
        )
        # Per-call retry override: only set when the caller actually supplied one.
        if not isinstance(max_retries, NotGiven):
            opts["max_retries"] = max_retries
        payload = await async_maybe_transform(
            {
                "request_id": request_id,
                "model_id": model_id,
            },
            future_retrieve_params.FutureRetrieveParams,
        )
        response = await self._post(
            "/api/v1/retrieve_future",
            body=payload,
            options=opts,
            # Union types cannot be passed in as arguments in the type system
            cast_to=cast(Any, FutureRetrieveResponse),
        )
        return cast(FutureRetrieveResponse, response)
class FuturesResourceWithRawResponse:
    """Expose each FuturesResource method with raw-response semantics."""

    def __init__(self, futures: FuturesResource) -> None:
        self._futures = futures
        self.retrieve = to_raw_response_wrapper(futures.retrieve)
class AsyncFuturesResourceWithRawResponse:
    """Expose each AsyncFuturesResource method with raw-response semantics."""

    def __init__(self, futures: AsyncFuturesResource) -> None:
        self._futures = futures
        self.retrieve = async_to_raw_response_wrapper(futures.retrieve)
class FuturesResourceWithStreamingResponse:
    """Expose each FuturesResource method with streamed-response semantics."""

    def __init__(self, futures: FuturesResource) -> None:
        self._futures = futures
        self.retrieve = to_streamed_response_wrapper(futures.retrieve)
class AsyncFuturesResourceWithStreamingResponse:
    """Expose each AsyncFuturesResource method with streamed-response semantics."""

    def __init__(self, futures: AsyncFuturesResource) -> None:
        self._futures = futures
        self.retrieve = async_to_streamed_response_wrapper(futures.retrieve)

View file

@ -0,0 +1,412 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing_extensions import Literal
import httpx
from ..types import ModelID, model_create_params, model_unload_params, model_get_info_params
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from .._utils import maybe_transform, async_maybe_transform
from .._compat import cached_property
from .._resource import SyncAPIResource, AsyncAPIResource
from .._response import (
to_raw_response_wrapper,
to_streamed_response_wrapper,
async_to_raw_response_wrapper,
async_to_streamed_response_wrapper,
)
from .._base_client import make_request_options
from ..types.model_id import ModelID
from ..types.get_info_response import GetInfoResponse
from ..types.lora_config_param import LoraConfigParam
from ..types.shared.untyped_api_future import UntypedAPIFuture
__all__ = ["ModelsResource", "AsyncModelsResource"]
class ModelsResource(SyncAPIResource):
    @cached_property
    def with_raw_response(self) -> ModelsResourceWithRawResponse:
        """Prefix any method call with this to get the raw response object
        instead of the parsed content.

        See https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
        """
        return ModelsResourceWithRawResponse(self)

    @cached_property
    def with_streaming_response(self) -> ModelsResourceWithStreamingResponse:
        """Like `.with_raw_response`, but the response body is not read eagerly.

        See https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
        """
        return ModelsResourceWithStreamingResponse(self)

    def create(
        self,
        *,
        base_model: str,
        lora_config: LoraConfigParam | NotGiven = NOT_GIVEN,
        type: Literal["create_model"] = "create_model",
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
    ) -> UntypedAPIFuture:
        """Creates a new model.

        Pass a LoRA config to create a new LoRA adapter for the
        base model.

        Args:
          extra_headers: Send extra headers

          extra_query: Add additional query parameters to the request

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds

          idempotency_key: Specify a custom idempotency key for this request
        """
        payload = maybe_transform(
            {
                "base_model": base_model,
                "lora_config": lora_config,
                "type": type,
            },
            model_create_params.ModelCreateParams,
        )
        opts = make_request_options(
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
            idempotency_key=idempotency_key,
        )
        return self._post(
            "/api/v1/create_model",
            body=payload,
            options=opts,
            cast_to=UntypedAPIFuture,
        )

    def get_info(
        self,
        *,
        model_id: ModelID,
        type: Literal["get_info"] | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
    ) -> GetInfoResponse:
        """
        Retrieves information about the current model

        Args:
          extra_headers: Send extra headers

          extra_query: Add additional query parameters to the request

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds

          idempotency_key: Specify a custom idempotency key for this request
        """
        payload = maybe_transform(
            {
                "model_id": model_id,
                "type": type,
            },
            model_get_info_params.ModelGetInfoParams,
        )
        opts = make_request_options(
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
            idempotency_key=idempotency_key,
        )
        return self._post(
            "/api/v1/get_info",
            body=payload,
            options=opts,
            cast_to=GetInfoResponse,
        )

    def unload(
        self,
        *,
        model_id: ModelID,
        type: Literal["unload_model"] | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
    ) -> UntypedAPIFuture:
        """
        Unload the model weights and ends the user's session.

        Args:
          extra_headers: Send extra headers

          extra_query: Add additional query parameters to the request

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds

          idempotency_key: Specify a custom idempotency key for this request
        """
        payload = maybe_transform(
            {
                "model_id": model_id,
                "type": type,
            },
            model_unload_params.ModelUnloadParams,
        )
        opts = make_request_options(
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
            idempotency_key=idempotency_key,
        )
        return self._post(
            "/api/v1/unload_model",
            body=payload,
            options=opts,
            cast_to=UntypedAPIFuture,
        )
class AsyncModelsResource(AsyncAPIResource):
    @cached_property
    def with_raw_response(self) -> AsyncModelsResourceWithRawResponse:
        """Prefix any method call with this to get the raw response object
        instead of the parsed content.

        See https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
        """
        return AsyncModelsResourceWithRawResponse(self)

    @cached_property
    def with_streaming_response(self) -> AsyncModelsResourceWithStreamingResponse:
        """Like `.with_raw_response`, but the response body is not read eagerly.

        See https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
        """
        return AsyncModelsResourceWithStreamingResponse(self)

    async def create(
        self,
        *,
        base_model: str,
        lora_config: LoraConfigParam | NotGiven = NOT_GIVEN,
        type: Literal["create_model"] = "create_model",
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
    ) -> UntypedAPIFuture:
        """Creates a new model.

        Pass a LoRA config to create a new LoRA adapter for the
        base model.

        Args:
          extra_headers: Send extra headers

          extra_query: Add additional query parameters to the request

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds

          idempotency_key: Specify a custom idempotency key for this request
        """
        payload = await async_maybe_transform(
            {
                "base_model": base_model,
                "lora_config": lora_config,
                "type": type,
            },
            model_create_params.ModelCreateParams,
        )
        opts = make_request_options(
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
            idempotency_key=idempotency_key,
        )
        return await self._post(
            "/api/v1/create_model",
            body=payload,
            options=opts,
            cast_to=UntypedAPIFuture,
        )

    async def get_info(
        self,
        *,
        model_id: ModelID,
        type: Literal["get_info"] | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
    ) -> GetInfoResponse:
        """
        Retrieves information about the current model

        Args:
          extra_headers: Send extra headers

          extra_query: Add additional query parameters to the request

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds

          idempotency_key: Specify a custom idempotency key for this request
        """
        payload = await async_maybe_transform(
            {
                "model_id": model_id,
                "type": type,
            },
            model_get_info_params.ModelGetInfoParams,
        )
        opts = make_request_options(
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
            idempotency_key=idempotency_key,
        )
        return await self._post(
            "/api/v1/get_info",
            body=payload,
            options=opts,
            cast_to=GetInfoResponse,
        )

    async def unload(
        self,
        *,
        model_id: ModelID,
        type: Literal["unload_model"] | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
    ) -> UntypedAPIFuture:
        """
        Unload the model weights and ends the user's session.

        Args:
          extra_headers: Send extra headers

          extra_query: Add additional query parameters to the request

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds

          idempotency_key: Specify a custom idempotency key for this request
        """
        payload = await async_maybe_transform(
            {
                "model_id": model_id,
                "type": type,
            },
            model_unload_params.ModelUnloadParams,
        )
        opts = make_request_options(
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
            idempotency_key=idempotency_key,
        )
        return await self._post(
            "/api/v1/unload_model",
            body=payload,
            options=opts,
            cast_to=UntypedAPIFuture,
        )
class ModelsResourceWithRawResponse:
    """Namespace exposing raw-response variants of each ``ModelsResource`` method."""

    def __init__(self, models: ModelsResource) -> None:
        self._models = models
        # Each public method is wrapped so calls return the raw HTTP response.
        self.create = to_raw_response_wrapper(models.create)
        self.get_info = to_raw_response_wrapper(models.get_info)
        self.unload = to_raw_response_wrapper(models.unload)
class AsyncModelsResourceWithRawResponse:
    """Namespace exposing raw-response variants of each ``AsyncModelsResource`` method."""

    def __init__(self, models: AsyncModelsResource) -> None:
        self._models = models
        # Each public method is wrapped so calls return the raw HTTP response.
        self.create = async_to_raw_response_wrapper(models.create)
        self.get_info = async_to_raw_response_wrapper(models.get_info)
        self.unload = async_to_raw_response_wrapper(models.unload)
class ModelsResourceWithStreamingResponse:
    """Namespace exposing streaming-response variants of each ``ModelsResource`` method."""

    def __init__(self, models: ModelsResource) -> None:
        self._models = models
        # Each public method is wrapped so response bodies are not read eagerly.
        self.create = to_streamed_response_wrapper(models.create)
        self.get_info = to_streamed_response_wrapper(models.get_info)
        self.unload = to_streamed_response_wrapper(models.unload)
class AsyncModelsResourceWithStreamingResponse:
    """Namespace exposing streaming-response variants of each ``AsyncModelsResource`` method."""

    def __init__(self, models: AsyncModelsResource) -> None:
        self._models = models
        # Each public method is wrapped so response bodies are not read eagerly.
        self.create = async_to_streamed_response_wrapper(models.create)
        self.get_info = async_to_streamed_response_wrapper(models.get_info)
        self.unload = async_to_streamed_response_wrapper(models.unload)

View file

@ -0,0 +1,394 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing_extensions import Literal
import httpx
from ..types import sampling_sample_params, sampling_asample_params
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from .._utils import maybe_transform, async_maybe_transform
from .._compat import cached_property
from .._resource import SyncAPIResource, AsyncAPIResource
from .._response import (
to_raw_response_wrapper,
to_streamed_response_wrapper,
async_to_raw_response_wrapper,
async_to_streamed_response_wrapper,
)
from .._base_client import make_request_options
from ..types.sample_response import SampleResponse
from ..types.model_input_param import ModelInputParam
from ..types.sampling_params_param import SamplingParamsParam
from ..types.shared.untyped_api_future import UntypedAPIFuture
__all__ = ["SamplingResource", "AsyncSamplingResource"]
class SamplingResource(SyncAPIResource):
    @cached_property
    def with_raw_response(self) -> SamplingResourceWithRawResponse:
        """
        This property can be used as a prefix for any HTTP method call to return
        the raw response object instead of the parsed content.

        For more information, see https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
        """
        return SamplingResourceWithRawResponse(self)

    @cached_property
    def with_streaming_response(self) -> SamplingResourceWithStreamingResponse:
        """
        An alternative to `.with_raw_response` that doesn't eagerly read the response body.

        For more information, see https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
        """
        return SamplingResourceWithStreamingResponse(self)

    def asample(
        self,
        *,
        num_samples: int = 1,
        prompt: ModelInputParam,
        sampling_params: SamplingParamsParam,
        base_model: str | NotGiven = NOT_GIVEN,
        model_path: str | NotGiven = NOT_GIVEN,
        prompt_logprobs: bool | NotGiven = NOT_GIVEN,
        type: Literal["sample"] | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
        max_retries: int | NotGiven = NOT_GIVEN,
    ) -> UntypedAPIFuture:
        """
        Generates samples from the model using the specified sampling parameters

        Args:
          num_samples: Number of samples to generate

          base_model: Optional base model name to sample from. Is inferred from model_path, if
              provided. If sampling against a base model, this is required.

          model_path: Optional tinker:// path to your model weights or LoRA weights. If not provided,
              samples against the base model.

          prompt_logprobs: If set to `true`, computes and returns logprobs on the prompt tokens. Defaults
              to false.

          extra_headers: Send extra headers

          extra_query: Add additional query parameters to the request

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds

          idempotency_key: Specify a custom idempotency key for this request

          max_retries: Override the client-level maximum retry count for this request
        """
        # Consistency: the async variants accept a per-request ``max_retries``
        # override; mirror that here so sync callers have the same control.
        # The keyword defaults to NOT_GIVEN, so existing callers are unaffected.
        options = make_request_options(
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
            idempotency_key=idempotency_key,
        )
        if max_retries is not NOT_GIVEN:
            options["max_retries"] = max_retries
        return self._post(
            "/api/v1/asample",
            body=maybe_transform(
                {
                    "num_samples": num_samples,
                    "prompt": prompt,
                    "sampling_params": sampling_params,
                    "base_model": base_model,
                    "model_path": model_path,
                    "prompt_logprobs": prompt_logprobs,
                    "type": type,
                },
                sampling_asample_params.SamplingAsampleParams,
            ),
            options=options,
            cast_to=UntypedAPIFuture,
        )

    def sample(
        self,
        *,
        num_samples: int = 1,
        prompt: ModelInputParam,
        sampling_params: SamplingParamsParam,
        base_model: str | NotGiven = NOT_GIVEN,
        model_path: str | NotGiven = NOT_GIVEN,
        prompt_logprobs: bool | NotGiven = NOT_GIVEN,
        type: Literal["sample"] | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
        max_retries: int | NotGiven = NOT_GIVEN,
    ) -> SampleResponse:
        """
        Generates samples from the model using the specified sampling parameters

        Args:
          num_samples: Number of samples to generate

          base_model: Optional base model name to sample from. Is inferred from model_path, if
              provided. If sampling against a base model, this is required.

          model_path: Optional tinker:// path to your model weights or LoRA weights. If not provided,
              samples against the base model.

          prompt_logprobs: If set to `true`, computes and returns logprobs on the prompt tokens. Defaults
              to false.

          extra_headers: Send extra headers

          extra_query: Add additional query parameters to the request

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds

          idempotency_key: Specify a custom idempotency key for this request

          max_retries: Override the client-level maximum retry count for this request
        """
        # See ``asample`` above: per-request retry override for parity with the
        # async resource; backward-compatible because it defaults to NOT_GIVEN.
        options = make_request_options(
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
            idempotency_key=idempotency_key,
        )
        if max_retries is not NOT_GIVEN:
            options["max_retries"] = max_retries
        return self._post(
            "/api/v1/sample",
            body=maybe_transform(
                {
                    "num_samples": num_samples,
                    "prompt": prompt,
                    "sampling_params": sampling_params,
                    "base_model": base_model,
                    "model_path": model_path,
                    "prompt_logprobs": prompt_logprobs,
                    "type": type,
                },
                sampling_sample_params.SamplingSampleParams,
            ),
            options=options,
            cast_to=SampleResponse,
        )
class AsyncSamplingResource(AsyncAPIResource):
    @cached_property
    def with_raw_response(self) -> AsyncSamplingResourceWithRawResponse:
        """
        Prefix any HTTP method call with this property to get the raw response
        object back instead of the parsed content.

        For more information, see https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
        """
        return AsyncSamplingResourceWithRawResponse(self)

    @cached_property
    def with_streaming_response(self) -> AsyncSamplingResourceWithStreamingResponse:
        """
        Like `.with_raw_response`, but the response body is not read eagerly.

        For more information, see https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
        """
        return AsyncSamplingResourceWithStreamingResponse(self)

    async def asample(
        self,
        *,
        num_samples: int = 1,
        prompt: ModelInputParam,
        sampling_params: SamplingParamsParam,
        base_model: str | NotGiven = NOT_GIVEN,
        model_path: str | NotGiven = NOT_GIVEN,
        prompt_logprobs: bool | NotGiven = NOT_GIVEN,
        type: Literal["sample"] | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
        max_retries: int | NotGiven = NOT_GIVEN,
    ) -> UntypedAPIFuture:
        """
        Generates samples from the model using the specified sampling parameters

        Args:
          num_samples: Number of samples to generate

          base_model: Optional base model name to sample from. Is inferred from model_path, if
              provided. If sampling against a base model, this is required.

          model_path: Optional tinker:// path to your model weights or LoRA weights. If not provided,
              samples against the base model.

          prompt_logprobs: If set to `true`, computes and returns logprobs on the prompt tokens. Defaults
              to false.

          extra_headers: Send extra headers

          extra_query: Add additional query parameters to the request

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds

          idempotency_key: Specify a custom idempotency key for this request
        """
        payload = await async_maybe_transform(
            {
                "num_samples": num_samples,
                "prompt": prompt,
                "sampling_params": sampling_params,
                "base_model": base_model,
                "model_path": model_path,
                "prompt_logprobs": prompt_logprobs,
                "type": type,
            },
            sampling_asample_params.SamplingAsampleParams,
        )
        request_options = make_request_options(
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
            idempotency_key=idempotency_key,
        )
        # A per-request retry override takes precedence over the client default.
        if max_retries is not NOT_GIVEN:
            request_options["max_retries"] = max_retries
        return await self._post(
            "/api/v1/asample",
            body=payload,
            options=request_options,
            cast_to=UntypedAPIFuture,
        )

    async def sample(
        self,
        *,
        num_samples: int = 1,
        prompt: ModelInputParam,
        sampling_params: SamplingParamsParam,
        base_model: str | NotGiven = NOT_GIVEN,
        model_path: str | NotGiven = NOT_GIVEN,
        prompt_logprobs: bool | NotGiven = NOT_GIVEN,
        type: Literal["sample"] | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
        max_retries: int | NotGiven = NOT_GIVEN,
    ) -> SampleResponse:
        """
        Generates samples from the model using the specified sampling parameters

        Args:
          num_samples: Number of samples to generate

          base_model: Optional base model name to sample from. Is inferred from model_path, if
              provided. If sampling against a base model, this is required.

          model_path: Optional tinker:// path to your model weights or LoRA weights. If not provided,
              samples against the base model.

          prompt_logprobs: If set to `true`, computes and returns logprobs on the prompt tokens. Defaults
              to false.

          extra_headers: Send extra headers

          extra_query: Add additional query parameters to the request

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds

          idempotency_key: Specify a custom idempotency key for this request
        """
        payload = await async_maybe_transform(
            {
                "num_samples": num_samples,
                "prompt": prompt,
                "sampling_params": sampling_params,
                "base_model": base_model,
                "model_path": model_path,
                "prompt_logprobs": prompt_logprobs,
                "type": type,
            },
            sampling_sample_params.SamplingSampleParams,
        )
        request_options = make_request_options(
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
            idempotency_key=idempotency_key,
        )
        # A per-request retry override takes precedence over the client default.
        if max_retries is not NOT_GIVEN:
            request_options["max_retries"] = max_retries
        return await self._post(
            "/api/v1/sample",
            body=payload,
            options=request_options,
            cast_to=SampleResponse,
        )
class SamplingResourceWithRawResponse:
    """Namespace exposing raw-response variants of each ``SamplingResource`` method."""

    def __init__(self, sampling: SamplingResource) -> None:
        self._sampling = sampling
        # Each public method is wrapped so calls return the raw HTTP response.
        self.asample = to_raw_response_wrapper(sampling.asample)
        self.sample = to_raw_response_wrapper(sampling.sample)
class AsyncSamplingResourceWithRawResponse:
    """Namespace exposing raw-response variants of each ``AsyncSamplingResource`` method."""

    def __init__(self, sampling: AsyncSamplingResource) -> None:
        self._sampling = sampling
        # Each public method is wrapped so calls return the raw HTTP response.
        self.asample = async_to_raw_response_wrapper(sampling.asample)
        self.sample = async_to_raw_response_wrapper(sampling.sample)
class SamplingResourceWithStreamingResponse:
    """Namespace exposing streaming-response variants of each ``SamplingResource`` method."""

    def __init__(self, sampling: SamplingResource) -> None:
        self._sampling = sampling
        # Each public method is wrapped so response bodies are not read eagerly.
        self.asample = to_streamed_response_wrapper(sampling.asample)
        self.sample = to_streamed_response_wrapper(sampling.sample)
class AsyncSamplingResourceWithStreamingResponse:
    """Namespace exposing streaming-response variants of each ``AsyncSamplingResource`` method."""

    def __init__(self, sampling: AsyncSamplingResource) -> None:
        self._sampling = sampling
        # Each public method is wrapped so response bodies are not read eagerly.
        self.asample = async_to_streamed_response_wrapper(sampling.asample)
        self.sample = async_to_streamed_response_wrapper(sampling.sample)

View file

@ -0,0 +1,186 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
import httpx
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from .._compat import cached_property
from .._resource import SyncAPIResource, AsyncAPIResource
from .._response import (
to_raw_response_wrapper,
to_streamed_response_wrapper,
async_to_raw_response_wrapper,
async_to_streamed_response_wrapper,
)
from .._base_client import make_request_options
from ..types.health_response import HealthResponse
from ..types.get_server_capabilities_response import GetServerCapabilitiesResponse
__all__ = ["ServiceResource", "AsyncServiceResource"]
class ServiceResource(SyncAPIResource):
    @cached_property
    def with_raw_response(self) -> ServiceResourceWithRawResponse:
        """
        Prefix any HTTP method call with this property to get the raw response
        object back instead of the parsed content.

        For more information, see https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
        """
        return ServiceResourceWithRawResponse(self)

    @cached_property
    def with_streaming_response(self) -> ServiceResourceWithStreamingResponse:
        """
        Like `.with_raw_response`, but the response body is not read eagerly.

        For more information, see https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
        """
        return ServiceResourceWithStreamingResponse(self)

    def get_server_capabilities(
        self,
        *,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> GetServerCapabilitiesResponse:
        """Retrieves information about supported models and server capabilities"""
        request_options = make_request_options(
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
        )
        return self._get(
            "/api/v1/get_server_capabilities",
            options=request_options,
            cast_to=GetServerCapabilitiesResponse,
        )

    def health_check(
        self,
        *,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> HealthResponse:
        """Checks if the API server is ready"""
        request_options = make_request_options(
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
        )
        return self._get(
            "/api/v1/healthz",
            options=request_options,
            cast_to=HealthResponse,
        )
class AsyncServiceResource(AsyncAPIResource):
    @cached_property
    def with_raw_response(self) -> AsyncServiceResourceWithRawResponse:
        """
        Prefix any HTTP method call with this property to get the raw response
        object back instead of the parsed content.

        For more information, see https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
        """
        return AsyncServiceResourceWithRawResponse(self)

    @cached_property
    def with_streaming_response(self) -> AsyncServiceResourceWithStreamingResponse:
        """
        Like `.with_raw_response`, but the response body is not read eagerly.

        For more information, see https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
        """
        return AsyncServiceResourceWithStreamingResponse(self)

    async def get_server_capabilities(
        self,
        *,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> GetServerCapabilitiesResponse:
        """Retrieves information about supported models and server capabilities"""
        request_options = make_request_options(
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
        )
        return await self._get(
            "/api/v1/get_server_capabilities",
            options=request_options,
            cast_to=GetServerCapabilitiesResponse,
        )

    async def health_check(
        self,
        *,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> HealthResponse:
        """Checks if the API server is ready"""
        request_options = make_request_options(
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
        )
        return await self._get(
            "/api/v1/healthz",
            options=request_options,
            cast_to=HealthResponse,
        )
class ServiceResourceWithRawResponse:
    """Namespace exposing raw-response variants of each ``ServiceResource`` method."""

    def __init__(self, service: ServiceResource) -> None:
        self._service = service
        # Each public method is wrapped so calls return the raw HTTP response.
        self.get_server_capabilities = to_raw_response_wrapper(service.get_server_capabilities)
        self.health_check = to_raw_response_wrapper(service.health_check)
class AsyncServiceResourceWithRawResponse:
    """Namespace exposing raw-response variants of each ``AsyncServiceResource`` method."""

    def __init__(self, service: AsyncServiceResource) -> None:
        self._service = service
        # Each public method is wrapped so calls return the raw HTTP response.
        self.get_server_capabilities = async_to_raw_response_wrapper(service.get_server_capabilities)
        self.health_check = async_to_raw_response_wrapper(service.health_check)
class ServiceResourceWithStreamingResponse:
    """Namespace exposing streaming-response variants of each ``ServiceResource`` method."""

    def __init__(self, service: ServiceResource) -> None:
        self._service = service
        # Each public method is wrapped so response bodies are not read eagerly.
        self.get_server_capabilities = to_streamed_response_wrapper(service.get_server_capabilities)
        self.health_check = to_streamed_response_wrapper(service.health_check)
class AsyncServiceResourceWithStreamingResponse:
    """Namespace exposing streaming-response variants of each ``AsyncServiceResource`` method."""

    def __init__(self, service: AsyncServiceResource) -> None:
        self._service = service
        # Each public method is wrapped so response bodies are not read eagerly.
        self.get_server_capabilities = async_to_streamed_response_wrapper(service.get_server_capabilities)
        self.health_check = async_to_streamed_response_wrapper(service.health_check)

View file

@ -0,0 +1,210 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import Iterable
import httpx
from ..types import telemetry_send_params
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from .._utils import maybe_transform, async_maybe_transform
from .._compat import cached_property
from .._resource import SyncAPIResource, AsyncAPIResource
from .._response import (
to_raw_response_wrapper,
to_streamed_response_wrapper,
async_to_raw_response_wrapper,
async_to_streamed_response_wrapper,
)
from .._base_client import make_request_options
from ..types.telemetry_response import TelemetryResponse
from ..types.telemetry_event_param import TelemetryEventParam
__all__ = ["TelemetryResource", "AsyncTelemetryResource"]
class TelemetryResource(SyncAPIResource):
    @cached_property
    def with_raw_response(self) -> TelemetryResourceWithRawResponse:
        """
        Prefix any HTTP method call with this property to get the raw response
        object back instead of the parsed content.

        For more information, see https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
        """
        return TelemetryResourceWithRawResponse(self)

    @cached_property
    def with_streaming_response(self) -> TelemetryResourceWithStreamingResponse:
        """
        Like `.with_raw_response`, but the response body is not read eagerly.

        For more information, see https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
        """
        return TelemetryResourceWithStreamingResponse(self)

    def send(
        self,
        *,
        events: Iterable[TelemetryEventParam],
        platform: str,
        sdk_version: str,
        session_id: str,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
    ) -> TelemetryResponse:
        """
        Accepts batches of SDK telemetry events for analytics and diagnostics

        Args:
          platform: Host platform name

          sdk_version: SDK version string

          extra_headers: Send extra headers

          extra_query: Add additional query parameters to the request

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds

          idempotency_key: Specify a custom idempotency key for this request
        """
        payload = maybe_transform(
            {
                "events": events,
                "platform": platform,
                "sdk_version": sdk_version,
                "session_id": session_id,
            },
            telemetry_send_params.TelemetrySendParams,
        )
        request_options = make_request_options(
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
            idempotency_key=idempotency_key,
        )
        return self._post(
            "/api/v1/telemetry",
            body=payload,
            options=request_options,
            cast_to=TelemetryResponse,
        )
class AsyncTelemetryResource(AsyncAPIResource):
    @cached_property
    def with_raw_response(self) -> AsyncTelemetryResourceWithRawResponse:
        """
        Prefix any HTTP method call with this property to get the raw response
        object back instead of the parsed content.

        For more information, see https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
        """
        return AsyncTelemetryResourceWithRawResponse(self)

    @cached_property
    def with_streaming_response(self) -> AsyncTelemetryResourceWithStreamingResponse:
        """
        Like `.with_raw_response`, but the response body is not read eagerly.

        For more information, see https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
        """
        return AsyncTelemetryResourceWithStreamingResponse(self)

    async def send(
        self,
        *,
        events: Iterable[TelemetryEventParam],
        platform: str,
        sdk_version: str,
        session_id: str,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
    ) -> TelemetryResponse:
        """
        Accepts batches of SDK telemetry events for analytics and diagnostics

        Args:
          platform: Host platform name

          sdk_version: SDK version string

          extra_headers: Send extra headers

          extra_query: Add additional query parameters to the request

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds

          idempotency_key: Specify a custom idempotency key for this request
        """
        payload = await async_maybe_transform(
            {
                "events": events,
                "platform": platform,
                "sdk_version": sdk_version,
                "session_id": session_id,
            },
            telemetry_send_params.TelemetrySendParams,
        )
        request_options = make_request_options(
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
            idempotency_key=idempotency_key,
        )
        return await self._post(
            "/api/v1/telemetry",
            body=payload,
            options=request_options,
            cast_to=TelemetryResponse,
        )
class TelemetryResourceWithRawResponse:
    """Namespace exposing raw-response variants of each ``TelemetryResource`` method."""

    def __init__(self, telemetry: TelemetryResource) -> None:
        self._telemetry = telemetry
        # The public method is wrapped so calls return the raw HTTP response.
        self.send = to_raw_response_wrapper(telemetry.send)
class AsyncTelemetryResourceWithRawResponse:
    """Namespace exposing raw-response variants of each ``AsyncTelemetryResource`` method."""

    def __init__(self, telemetry: AsyncTelemetryResource) -> None:
        self._telemetry = telemetry
        # The public method is wrapped so calls return the raw HTTP response.
        self.send = async_to_raw_response_wrapper(telemetry.send)
class TelemetryResourceWithStreamingResponse:
    """Namespace exposing streaming-response variants of each ``TelemetryResource`` method."""

    def __init__(self, telemetry: TelemetryResource) -> None:
        self._telemetry = telemetry
        # The public method is wrapped so the response body is not read eagerly.
        self.send = to_streamed_response_wrapper(telemetry.send)
class AsyncTelemetryResourceWithStreamingResponse:
    """Namespace exposing streaming-response variants of each ``AsyncTelemetryResource`` method."""

    def __init__(self, telemetry: AsyncTelemetryResource) -> None:
        self._telemetry = telemetry
        # The public method is wrapped so the response body is not read eagerly.
        self.send = async_to_streamed_response_wrapper(telemetry.send)

View file

@ -0,0 +1,412 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing_extensions import Literal
import httpx
from ..types import (
ModelID,
training_forward_params,
training_optim_step_params,
training_forward_backward_params,
)
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from .._utils import maybe_transform, async_maybe_transform
from .._compat import cached_property
from .._resource import SyncAPIResource, AsyncAPIResource
from .._response import (
to_raw_response_wrapper,
to_streamed_response_wrapper,
async_to_raw_response_wrapper,
async_to_streamed_response_wrapper,
)
from .._base_client import make_request_options
from ..types.model_id import ModelID
from ..types.shared.untyped_api_future import UntypedAPIFuture
from ..types.forward_backward_input_param import ForwardBackwardInputParam
__all__ = ["TrainingResource", "AsyncTrainingResource"]
class TrainingResource(SyncAPIResource):
    @cached_property
    def with_raw_response(self) -> TrainingResourceWithRawResponse:
        """
        Prefix any HTTP method call with this property to get the raw response
        object back instead of the parsed content.

        For more information, see https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
        """
        return TrainingResourceWithRawResponse(self)

    @cached_property
    def with_streaming_response(self) -> TrainingResourceWithStreamingResponse:
        """
        Like `.with_raw_response`, but the response body is not read eagerly.

        For more information, see https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
        """
        return TrainingResourceWithStreamingResponse(self)

    def forward(
        self,
        *,
        forward_input: ForwardBackwardInputParam,
        model_id: ModelID,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
    ) -> UntypedAPIFuture:
        """
        Performs a forward pass through the model

        Args:
          extra_headers: Send extra headers

          extra_query: Add additional query parameters to the request

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds

          idempotency_key: Specify a custom idempotency key for this request
        """
        payload = maybe_transform(
            {
                "forward_input": forward_input,
                "model_id": model_id,
            },
            training_forward_params.TrainingForwardParams,
        )
        request_options = make_request_options(
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
            idempotency_key=idempotency_key,
        )
        return self._post(
            "/api/v1/forward",
            body=payload,
            options=request_options,
            cast_to=UntypedAPIFuture,
        )

    def forward_backward(
        self,
        *,
        forward_backward_input: ForwardBackwardInputParam,
        model_id: ModelID,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
    ) -> UntypedAPIFuture:
        """
        Performs a forward and backward pass through the model

        Args:
          extra_headers: Send extra headers

          extra_query: Add additional query parameters to the request

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds

          idempotency_key: Specify a custom idempotency key for this request
        """
        payload = maybe_transform(
            {
                "forward_backward_input": forward_backward_input,
                "model_id": model_id,
            },
            training_forward_backward_params.TrainingForwardBackwardParams,
        )
        request_options = make_request_options(
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
            idempotency_key=idempotency_key,
        )
        return self._post(
            "/api/v1/forward_backward",
            body=payload,
            options=request_options,
            cast_to=UntypedAPIFuture,
        )

    def optim_step(
        self,
        *,
        adam_params: training_optim_step_params.AdamParams,
        model_id: ModelID,
        type: Literal["optim_step"] | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
    ) -> UntypedAPIFuture:
        """
        Performs an optimization step using AdamW optimizer

        Args:
          extra_headers: Send extra headers

          extra_query: Add additional query parameters to the request

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds

          idempotency_key: Specify a custom idempotency key for this request
        """
        payload = maybe_transform(
            {
                "adam_params": adam_params,
                "model_id": model_id,
                "type": type,
            },
            training_optim_step_params.TrainingOptimStepParams,
        )
        request_options = make_request_options(
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
            idempotency_key=idempotency_key,
        )
        return self._post(
            "/api/v1/optim_step",
            body=payload,
            options=request_options,
            cast_to=UntypedAPIFuture,
        )
class AsyncTrainingResource(AsyncAPIResource):
    # Async variant of TrainingResource: each endpoint issues an awaited POST
    # and returns an UntypedAPIFuture handle.
    @cached_property
    def with_raw_response(self) -> AsyncTrainingResourceWithRawResponse:
        """
        This property can be used as a prefix for any HTTP method call to return
        the raw response object instead of the parsed content.
        For more information, see https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
        """
        return AsyncTrainingResourceWithRawResponse(self)
    @cached_property
    def with_streaming_response(self) -> AsyncTrainingResourceWithStreamingResponse:
        """
        An alternative to `.with_raw_response` that doesn't eagerly read the response body.
        For more information, see https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
        """
        return AsyncTrainingResourceWithStreamingResponse(self)
    async def forward(
        self,
        *,
        forward_input: ForwardBackwardInputParam,
        model_id: ModelID,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
    ) -> UntypedAPIFuture:
        """
        Performs a forward pass through the model
        Args:
          extra_headers: Send extra headers
          extra_query: Add additional query parameters to the request
          extra_body: Add additional JSON properties to the request
          timeout: Override the client-level default timeout for this request, in seconds
          idempotency_key: Specify a custom idempotency key for this request
        """
        return await self._post(
            "/api/v1/forward",
            body=await async_maybe_transform(
                {
                    "forward_input": forward_input,
                    "model_id": model_id,
                },
                training_forward_params.TrainingForwardParams,
            ),
            options=make_request_options(
                extra_headers=extra_headers,
                extra_query=extra_query,
                extra_body=extra_body,
                timeout=timeout,
                idempotency_key=idempotency_key,
            ),
            cast_to=UntypedAPIFuture,
        )
    async def forward_backward(
        self,
        *,
        forward_backward_input: ForwardBackwardInputParam,
        model_id: ModelID,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
    ) -> UntypedAPIFuture:
        """
        Performs a forward and backward pass through the model
        Args:
          extra_headers: Send extra headers
          extra_query: Add additional query parameters to the request
          extra_body: Add additional JSON properties to the request
          timeout: Override the client-level default timeout for this request, in seconds
          idempotency_key: Specify a custom idempotency key for this request
        """
        return await self._post(
            "/api/v1/forward_backward",
            body=await async_maybe_transform(
                {
                    "forward_backward_input": forward_backward_input,
                    "model_id": model_id,
                },
                training_forward_backward_params.TrainingForwardBackwardParams,
            ),
            options=make_request_options(
                extra_headers=extra_headers,
                extra_query=extra_query,
                extra_body=extra_body,
                timeout=timeout,
                idempotency_key=idempotency_key,
            ),
            cast_to=UntypedAPIFuture,
        )
    async def optim_step(
        self,
        *,
        adam_params: training_optim_step_params.AdamParams,
        model_id: ModelID,
        type: Literal["optim_step"] | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
    ) -> UntypedAPIFuture:
        """
        Performs an optimization step using AdamW optimizer
        Args:
          extra_headers: Send extra headers
          extra_query: Add additional query parameters to the request
          extra_body: Add additional JSON properties to the request
          timeout: Override the client-level default timeout for this request, in seconds
          idempotency_key: Specify a custom idempotency key for this request
        """
        return await self._post(
            "/api/v1/optim_step",
            body=await async_maybe_transform(
                {
                    "adam_params": adam_params,
                    "model_id": model_id,
                    "type": type,
                },
                training_optim_step_params.TrainingOptimStepParams,
            ),
            options=make_request_options(
                extra_headers=extra_headers,
                extra_query=extra_query,
                extra_body=extra_body,
                timeout=timeout,
                idempotency_key=idempotency_key,
            ),
            cast_to=UntypedAPIFuture,
        )
class TrainingResourceWithRawResponse:
    """Mirror of TrainingResource whose methods return the raw HTTP response."""

    def __init__(self, training: TrainingResource) -> None:
        self._training = training
        # Wrap each endpoint method so calling it yields the raw response.
        for _name in ("forward", "forward_backward", "optim_step"):
            setattr(self, _name, to_raw_response_wrapper(getattr(training, _name)))
class AsyncTrainingResourceWithRawResponse:
    """Mirror of AsyncTrainingResource whose methods return the raw HTTP response."""

    def __init__(self, training: AsyncTrainingResource) -> None:
        self._training = training
        # Wrap each endpoint method so awaiting it yields the raw response.
        for _name in ("forward", "forward_backward", "optim_step"):
            setattr(self, _name, async_to_raw_response_wrapper(getattr(training, _name)))
class TrainingResourceWithStreamingResponse:
    """Mirror of TrainingResource whose methods return streamed responses."""

    def __init__(self, training: TrainingResource) -> None:
        self._training = training
        # Wrap each endpoint method in the streaming-response wrapper.
        for _name in ("forward", "forward_backward", "optim_step"):
            setattr(self, _name, to_streamed_response_wrapper(getattr(training, _name)))
class AsyncTrainingResourceWithStreamingResponse:
    """Mirror of AsyncTrainingResource whose methods return streamed responses."""

    def __init__(self, training: AsyncTrainingResource) -> None:
        self._training = training
        # Wrap each endpoint method in the async streaming-response wrapper.
        for _name in ("forward", "forward_backward", "optim_step"):
            setattr(self, _name, async_to_streamed_response_wrapper(getattr(training, _name)))

View file

@ -0,0 +1,596 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing_extensions import Literal
import httpx
from ..types import (
ModelID,
CheckpointsListResponse,
weight_load_params,
weight_save_params,
weight_save_for_sampler_params,
)
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven, NoneType
from .._utils import maybe_transform, async_maybe_transform
from .._compat import cached_property
from .._resource import SyncAPIResource, AsyncAPIResource
from .._response import (
to_raw_response_wrapper,
to_streamed_response_wrapper,
async_to_raw_response_wrapper,
async_to_streamed_response_wrapper,
)
from .._base_client import make_request_options
from ..types.shared.untyped_api_future import UntypedAPIFuture
__all__ = ["WeightsResource", "AsyncWeightsResource"]
class WeightsResource(SyncAPIResource):
    """Synchronous access to weight endpoints: load, save, save-for-sampler,
    checkpoint listing and checkpoint deletion."""

    @cached_property
    def with_raw_response(self) -> WeightsResourceWithRawResponse:
        """
        This property can be used as a prefix for any HTTP method call to return
        the raw response object instead of the parsed content.
        For more information, see https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
        """
        return WeightsResourceWithRawResponse(self)
    @cached_property
    def with_streaming_response(self) -> WeightsResourceWithStreamingResponse:
        """
        An alternative to `.with_raw_response` that doesn't eagerly read the response body.
        For more information, see https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
        """
        return WeightsResourceWithStreamingResponse(self)
    def load(
        self,
        *,
        model_id: ModelID,
        path: str,
        type: Literal["load_weights"] | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
    ) -> UntypedAPIFuture:
        """
        Loads model weights from disk
        Args:
          path: A tinker URI for model weights at a specific step
          extra_headers: Send extra headers
          extra_query: Add additional query parameters to the request
          extra_body: Add additional JSON properties to the request
          timeout: Override the client-level default timeout for this request, in seconds
          idempotency_key: Specify a custom idempotency key for this request
        """
        return self._post(
            "/api/v1/load_weights",
            body=maybe_transform(
                {
                    "model_id": model_id,
                    "path": path,
                    "type": type,
                },
                weight_load_params.WeightLoadParams,
            ),
            options=make_request_options(
                extra_headers=extra_headers,
                extra_query=extra_query,
                extra_body=extra_body,
                timeout=timeout,
                idempotency_key=idempotency_key,
            ),
            cast_to=UntypedAPIFuture,
        )
    def save(
        self,
        *,
        model_id: ModelID,
        path: str | NotGiven = NOT_GIVEN,
        type: Literal["save_weights"] | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
    ) -> UntypedAPIFuture:
        """
        Saves the current model weights to disk
        Args:
          path: A file/directory name for the weights
          extra_headers: Send extra headers
          extra_query: Add additional query parameters to the request
          extra_body: Add additional JSON properties to the request
          timeout: Override the client-level default timeout for this request, in seconds
          idempotency_key: Specify a custom idempotency key for this request
        """
        return self._post(
            "/api/v1/save_weights",
            body=maybe_transform(
                {
                    "model_id": model_id,
                    "path": path,
                    "type": type,
                },
                weight_save_params.WeightSaveParams,
            ),
            options=make_request_options(
                extra_headers=extra_headers,
                extra_query=extra_query,
                extra_body=extra_body,
                timeout=timeout,
                idempotency_key=idempotency_key,
            ),
            cast_to=UntypedAPIFuture,
        )
    def save_for_sampler(
        self,
        *,
        model_id: ModelID,
        path: str | NotGiven = NOT_GIVEN,
        type: Literal["save_weights_for_sampler"] | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
    ) -> UntypedAPIFuture:
        """
        Saves weights in a format compatible with sampling/inference servers
        Args:
          path: A file/directory name for the weights
          extra_headers: Send extra headers
          extra_query: Add additional query parameters to the request
          extra_body: Add additional JSON properties to the request
          timeout: Override the client-level default timeout for this request, in seconds
          idempotency_key: Specify a custom idempotency key for this request
        """
        return self._post(
            "/api/v1/save_weights_for_sampler",
            body=maybe_transform(
                {
                    "model_id": model_id,
                    "path": path,
                    "type": type,
                },
                weight_save_for_sampler_params.WeightSaveForSamplerParams,
            ),
            options=make_request_options(
                extra_headers=extra_headers,
                extra_query=extra_query,
                extra_body=extra_body,
                timeout=timeout,
                idempotency_key=idempotency_key,
            ),
            cast_to=UntypedAPIFuture,
        )
    def list(
        self,
        model_id: ModelID,
        *,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> CheckpointsListResponse:
        """
        Lists available model checkpoints (both training and sampler)
        Args:
          model_id: The model ID to list checkpoints for
          extra_headers: Send extra headers
          extra_query: Add additional query parameters to the request
          extra_body: Add additional JSON properties to the request
          timeout: Override the client-level default timeout for this request, in seconds
        """
        if not model_id:
            raise ValueError(f"Expected a non-empty value for `model_id` but received {model_id!r}")
        # Fixed: use the same /training_runs/{model_id}/checkpoints route as the
        # async variant and delete_checkpoint (was /models/{model_id}/checkpoints).
        return self._get(
            f"/api/v1/training_runs/{model_id}/checkpoints",
            options=make_request_options(
                extra_headers=extra_headers,
                extra_query=extra_query,
                extra_body=extra_body,
                timeout=timeout,
            ),
            cast_to=CheckpointsListResponse,
        )
    def delete_checkpoint(
        self,
        *,
        model_id: ModelID,
        checkpoint_id: str,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> None:
        """Delete a checkpoint for the given training run."""
        if not model_id:
            raise ValueError(f"Expected a non-empty value for `model_id` but received {model_id!r}")
        if not checkpoint_id:
            raise ValueError(
                f"Expected a non-empty value for `checkpoint_id` but received {checkpoint_id!r}"
            )
        self._delete(
            f"/api/v1/training_runs/{model_id}/checkpoints/{checkpoint_id}",
            options=make_request_options(
                extra_headers=extra_headers,
                extra_query=extra_query,
                extra_body=extra_body,
                timeout=timeout,
            ),
            cast_to=NoneType,
        )
        return None
class AsyncWeightsResource(AsyncAPIResource):
    # Async variant of WeightsResource: load/save/save_for_sampler POST
    # endpoints plus checkpoint listing and deletion.
    @cached_property
    def with_raw_response(self) -> AsyncWeightsResourceWithRawResponse:
        """
        This property can be used as a prefix for any HTTP method call to return
        the raw response object instead of the parsed content.
        For more information, see https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
        """
        return AsyncWeightsResourceWithRawResponse(self)
    @cached_property
    def with_streaming_response(self) -> AsyncWeightsResourceWithStreamingResponse:
        """
        An alternative to `.with_raw_response` that doesn't eagerly read the response body.
        For more information, see https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
        """
        return AsyncWeightsResourceWithStreamingResponse(self)
    async def load(
        self,
        *,
        model_id: ModelID,
        path: str,
        type: Literal["load_weights"] | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
    ) -> UntypedAPIFuture:
        """
        Loads model weights from disk
        Args:
          path: A tinker URI for model weights at a specific step
          extra_headers: Send extra headers
          extra_query: Add additional query parameters to the request
          extra_body: Add additional JSON properties to the request
          timeout: Override the client-level default timeout for this request, in seconds
          idempotency_key: Specify a custom idempotency key for this request
        """
        return await self._post(
            "/api/v1/load_weights",
            body=await async_maybe_transform(
                {
                    "model_id": model_id,
                    "path": path,
                    "type": type,
                },
                weight_load_params.WeightLoadParams,
            ),
            options=make_request_options(
                extra_headers=extra_headers,
                extra_query=extra_query,
                extra_body=extra_body,
                timeout=timeout,
                idempotency_key=idempotency_key,
            ),
            cast_to=UntypedAPIFuture,
        )
    async def save(
        self,
        *,
        model_id: ModelID,
        path: str | NotGiven = NOT_GIVEN,
        type: Literal["save_weights"] | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
    ) -> UntypedAPIFuture:
        """
        Saves the current model weights to disk
        Args:
          path: A file/directory name for the weights
          extra_headers: Send extra headers
          extra_query: Add additional query parameters to the request
          extra_body: Add additional JSON properties to the request
          timeout: Override the client-level default timeout for this request, in seconds
          idempotency_key: Specify a custom idempotency key for this request
        """
        return await self._post(
            "/api/v1/save_weights",
            body=await async_maybe_transform(
                {
                    "model_id": model_id,
                    "path": path,
                    "type": type,
                },
                weight_save_params.WeightSaveParams,
            ),
            options=make_request_options(
                extra_headers=extra_headers,
                extra_query=extra_query,
                extra_body=extra_body,
                timeout=timeout,
                idempotency_key=idempotency_key,
            ),
            cast_to=UntypedAPIFuture,
        )
    async def save_for_sampler(
        self,
        *,
        model_id: ModelID,
        path: str | NotGiven = NOT_GIVEN,
        type: Literal["save_weights_for_sampler"] | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        idempotency_key: str | None = None,
    ) -> UntypedAPIFuture:
        """
        Saves weights in a format compatible with sampling/inference servers
        Args:
          path: A file/directory name for the weights
          extra_headers: Send extra headers
          extra_query: Add additional query parameters to the request
          extra_body: Add additional JSON properties to the request
          timeout: Override the client-level default timeout for this request, in seconds
          idempotency_key: Specify a custom idempotency key for this request
        """
        return await self._post(
            "/api/v1/save_weights_for_sampler",
            body=await async_maybe_transform(
                {
                    "model_id": model_id,
                    "path": path,
                    "type": type,
                },
                weight_save_for_sampler_params.WeightSaveForSamplerParams,
            ),
            options=make_request_options(
                extra_headers=extra_headers,
                extra_query=extra_query,
                extra_body=extra_body,
                timeout=timeout,
                idempotency_key=idempotency_key,
            ),
            cast_to=UntypedAPIFuture,
        )
    async def list(
        self,
        model_id: ModelID,
        *,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> CheckpointsListResponse:
        """
        Lists available model checkpoints (both training and sampler)
        Args:
          model_id: The model ID to list checkpoints for
          extra_headers: Send extra headers
          extra_query: Add additional query parameters to the request
          extra_body: Add additional JSON properties to the request
          timeout: Override the client-level default timeout for this request, in seconds
        """
        if not model_id:
            raise ValueError(f"Expected a non-empty value for `model_id` but received {model_id!r}")
        return await self._get(
            f"/api/v1/training_runs/{model_id}/checkpoints",
            options=make_request_options(
                extra_headers=extra_headers,
                extra_query=extra_query,
                extra_body=extra_body,
                timeout=timeout,
            ),
            cast_to=CheckpointsListResponse,
        )
    async def delete_checkpoint(
        self,
        *,
        model_id: ModelID,
        checkpoint_id: str,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> None:
        """Delete a checkpoint for the given training run."""
        if not model_id:
            raise ValueError(f"Expected a non-empty value for `model_id` but received {model_id!r}")
        if not checkpoint_id:
            raise ValueError(
                f"Expected a non-empty value for `checkpoint_id` but received {checkpoint_id!r}"
            )
        await self._delete(
            f"/api/v1/training_runs/{model_id}/checkpoints/{checkpoint_id}",
            options=make_request_options(
                extra_headers=extra_headers,
                extra_query=extra_query,
                extra_body=extra_body,
                timeout=timeout,
            ),
            cast_to=NoneType,
        )
        return None
class WeightsResourceWithRawResponse:
    """Mirror of WeightsResource whose methods return the raw HTTP response."""

    def __init__(self, weights: WeightsResource) -> None:
        self._weights = weights
        # Wrap each endpoint method so calling it yields the raw response.
        for _name in ("load", "save", "save_for_sampler", "list", "delete_checkpoint"):
            setattr(self, _name, to_raw_response_wrapper(getattr(weights, _name)))
class AsyncWeightsResourceWithRawResponse:
    """Mirror of AsyncWeightsResource whose methods return the raw HTTP response."""

    def __init__(self, weights: AsyncWeightsResource) -> None:
        self._weights = weights
        # Wrap each endpoint method so awaiting it yields the raw response.
        for _name in ("load", "save", "save_for_sampler", "list", "delete_checkpoint"):
            setattr(self, _name, async_to_raw_response_wrapper(getattr(weights, _name)))
class WeightsResourceWithStreamingResponse:
    """Mirror of WeightsResource whose methods return streamed responses."""

    def __init__(self, weights: WeightsResource) -> None:
        self._weights = weights
        # Wrap each endpoint method in the streaming-response wrapper.
        for _name in ("load", "save", "save_for_sampler", "list", "delete_checkpoint"):
            setattr(self, _name, to_streamed_response_wrapper(getattr(weights, _name)))
class AsyncWeightsResourceWithStreamingResponse:
    """Mirror of AsyncWeightsResource whose methods return streamed responses."""

    def __init__(self, weights: AsyncWeightsResource) -> None:
        self._weights = weights
        # Wrap each endpoint method in the async streaming-response wrapper.
        for _name in ("load", "save", "save_for_sampler", "list", "delete_checkpoint"):
            setattr(self, _name, async_to_streamed_response_wrapper(getattr(weights, _name)))

View file

@ -0,0 +1,90 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
# There's an underscore in front of *Param classes (TypedDict) because they shouldn't be used.
from .datum import Datum as Datum
from .shared import UntypedAPIFuture as UntypedAPIFuture
from .model_id import ModelID as ModelID
from .severity import Severity as Severity
from .event_type import EventType as EventType
from .request_id import RequestID as RequestID
from .datum_param import DatumParam as _DatumParam
from .lora_config import LoraConfig as LoraConfig
from .model_input import ModelInput as ModelInput
from .stop_reason import StopReason as StopReason
from .tensor_data import TensorData as TensorData
from .loss_fn_type import LossFnType as LossFnType
from .tensor_dtype import TensorDtype as TensorDtype
from .error_response import ErrorResponse as ErrorResponse
from .loss_fn_inputs import LossFnInputs as LossFnInputs
from .loss_fn_output import LossFnOutput as LossFnOutput
from .sample_request import SampleRequest as SampleRequest
from .health_response import HealthResponse as HealthResponse
from .sample_response import SampleResponse as SampleResponse
from .sampling_params import SamplingParams as SamplingParams
from .telemetry_batch import TelemetryBatch as TelemetryBatch
from .telemetry_event import TelemetryEvent as TelemetryEvent
from .get_info_request import GetInfoRequest as GetInfoRequest
from .sampled_sequence import SampledSequence as SampledSequence
from .get_info_response import GetInfoResponse as GetInfoResponse
from .get_info_response import ModelData as ModelData
from .lora_config_param import LoraConfigParam as _LoraConfigParam
from .model_input_chunk import ModelInputChunk as ModelInputChunk
from .model_input_param import ModelInputParam as _ModelInputParam
from .tensor_data_param import TensorDataParam as _TensorDataParam
from .encoded_text_chunk import EncodedTextChunk as EncodedTextChunk
from .optim_step_request import OptimStepRequest as OptimStepRequest
from .checkpoint import Checkpoint as Checkpoint, CheckpointType as CheckpointType, ParsedCheckpointTinkerPath as ParsedCheckpointTinkerPath
from .weight_load_params import WeightLoadParams as _WeightLoadParams
from .weight_save_params import WeightSaveParams as _WeightSaveParams
from .checkpoints_list_response import CheckpointsListResponse as CheckpointsListResponse
from .cursor import Cursor as Cursor
from .training_run_ids_response import TrainingRunIdsResponse as TrainingRunIdsResponse
from .training_runs_response import TrainingRunsResponse as TrainingRunsResponse
from .forward_backward_input_param import ForwardBackwardInputParam as _ForwardBackwardInputParam
from .forward_backward_input import ForwardBackwardInput as ForwardBackwardInput
from .forward_backward_output import ForwardBackwardOutput as ForwardBackwardOutput
from .model_create_params import ModelCreateParams as _ModelCreateParams
from .model_unload_params import ModelUnloadParams as _ModelUnloadParams
from .session_end_event import SessionEndEvent as SessionEndEvent
from .telemetry_response import TelemetryResponse as TelemetryResponse
from .try_again_response import TryAgainResponse as TryAgainResponse
from .optim_step_response import OptimStepResponse as OptimStepResponse
from .session_start_event import SessionStartEvent as SessionStartEvent
from .create_model_request import CreateModelRequest as CreateModelRequest
from .load_weights_request import LoadWeightsRequest as LoadWeightsRequest
from .loss_fn_inputs_param import LossFnInputsParam as _LossFnInputsParam
from .save_weights_request import SaveWeightsRequest as SaveWeightsRequest
from .unload_model_request import UnloadModelRequest as UnloadModelRequest
from .create_model_response import CreateModelResponse as CreateModelResponse
from .load_weights_response import LoadWeightsResponse as LoadWeightsResponse
from .model_get_info_params import ModelGetInfoParams as _ModelGetInfoParams
from .save_weights_response import SaveWeightsResponse as SaveWeightsResponse
from .telemetry_event_param import TelemetryEventParam as TelemetryEventParam
from .telemetry_send_params import TelemetrySendParams as TelemetrySendParams
from .unload_model_response import UnloadModelResponse as UnloadModelResponse
from .future_retrieve_params import FutureRetrieveParams as _FutureRetrieveParams
from .model_input_chunk_param import ModelInputChunkParam as _ModelInputChunkParam
from .training_forward_params import TrainingForwardParams as _TrainingForwardParams
from .encoded_text_chunk_param import EncodedTextChunkParam as _EncodedTextChunkParam
from .sampling_params_param import SamplingParamsParam as _SamplingParamsParam
from .sampling_sample_params import SamplingSampleParams as _SamplingSampleParams
from .sampling_asample_params import SamplingAsampleParams as _SamplingAsampleParams
from .future_retrieve_response import FutureRetrieveResponse as FutureRetrieveResponse
from .compute_logprobs_response import ComputeLogprobsResponse as ComputeLogprobsResponse
from .image_asset_pointer_chunk import ImageAssetPointerChunk as ImageAssetPointerChunk
from .training_optim_step_params import TrainingOptimStepParams as _TrainingOptimStepParams
from .weight_save_for_sampler_params import WeightSaveForSamplerParams as _WeightSaveForSamplerParams
from .image_asset_pointer_chunk_param import ImageAssetPointerChunkParam as _ImageAssetPointerChunkParam
from .session_end_event_param import SessionEndEventParam as _SessionEndEventParam
from .session_start_event_param import SessionStartEventParam as _SessionStartEventParam
from .unhandled_exception_event import UnhandledExceptionEvent as UnhandledExceptionEvent
from .unhandled_exception_event_param import UnhandledExceptionEventParam as _UnhandledExceptionEventParam
from .get_server_capabilities_response import GetServerCapabilitiesResponse as GetServerCapabilitiesResponse
from .save_weights_for_sampler_request import SaveWeightsForSamplerRequest as SaveWeightsForSamplerRequest
from .get_server_capabilities_response import SupportedModel as SupportedModel
from .training_forward_backward_params import TrainingForwardBackwardParams as _TrainingForwardBackwardParams
from .save_weights_for_sampler_response import SaveWeightsForSamplerResponse as SaveWeightsForSamplerResponse
from .optim_step_request import AdamParams as AdamParams
from .training_run import TrainingRun as TrainingRun

View file

@ -0,0 +1,54 @@
from datetime import datetime
from typing import Literal

from .._models import BaseModel

# ParsedCheckpointTinkerPath is defined in this module and re-exported by the
# package's types/__init__, so it belongs in the public API list.
__all__ = ["Checkpoint", "CheckpointType", "ParsedCheckpointTinkerPath"]

# The two kinds of checkpoints a model can produce.
CheckpointType = Literal["training", "sampler"]
class Checkpoint(BaseModel):
    """Metadata describing one saved checkpoint of a model."""
    checkpoint_id: str
    """The checkpoint ID"""
    checkpoint_type: CheckpointType
    """The type of checkpoint (training or sampler)"""
    time: datetime
    """The time when the checkpoint was created"""
    tinker_path: str
    """The tinker path to the checkpoint"""
class ParsedCheckpointTinkerPath(BaseModel):
    """Structured form of a checkpoint URI.

    A valid path looks like
    ``tinker://<training_run_id>/<weights|sampler_weights>/<name>``.
    """

    tinker_path: str
    """The tinker path to the checkpoint"""
    training_run_id: str
    """The training run ID"""
    checkpoint_type: CheckpointType
    """The type of checkpoint (training or sampler)"""
    checkpoint_id: str
    """The checkpoint ID"""

    @classmethod
    def from_tinker_path(cls, tinker_path: str) -> "ParsedCheckpointTinkerPath":
        """Parse a tinker path to an instance of ParsedCheckpointTinkerPath.

        Raises:
            ValueError: if the scheme is wrong, the path does not have exactly
                three segments, any segment is empty, or the kind segment is
                neither ``weights`` nor ``sampler_weights``.
        """
        scheme = "tinker://"
        if not tinker_path.startswith(scheme):
            raise ValueError(f"Invalid tinker path: {tinker_path}")
        # Strip the scheme by its length instead of the magic slice [9:].
        parts = tinker_path[len(scheme):].split("/")
        # Also reject empty segments (e.g. "tinker:///weights/x"), which the
        # previous length-only check silently accepted.
        if len(parts) != 3 or not all(parts):
            raise ValueError(f"Invalid tinker path: {tinker_path}")
        if parts[1] not in ["weights", "sampler_weights"]:
            raise ValueError(f"Invalid tinker path: {tinker_path}")
        checkpoint_type = "training" if parts[1] == "weights" else "sampler"
        return cls(
            tinker_path=tinker_path,
            training_run_id=parts[0],
            checkpoint_type=checkpoint_type,
            checkpoint_id="/".join(parts[1:]),
        )

View file

@ -0,0 +1,9 @@
from .._models import BaseModel
from .checkpoint import Checkpoint
__all__ = ["CheckpointsListResponse"]
class CheckpointsListResponse(BaseModel):
    """Response payload enumerating the checkpoints known for one model."""
    checkpoints: list[Checkpoint]
    """List of available model checkpoints for the model"""

View file

@ -0,0 +1,14 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Optional
from typing_extensions import Literal, Sequence
from .._models import BaseModel
__all__ = ["ComputeLogprobsResponse"]
class ComputeLogprobsResponse(BaseModel):
    """Response carrying a sequence of log probabilities; entries may be None."""
    logprobs: Sequence[Optional[float]]
    type: Literal["compute_logprobs"] = "compute_logprobs"

View file

@ -0,0 +1,17 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Optional
from typing_extensions import Literal
from .._models import StrictBase
from .lora_config import LoraConfig
__all__ = ["CreateModelRequest"]
class CreateModelRequest(StrictBase):
    """Request body for creating a model, optionally with a LoRA configuration."""
    base_model: str
    lora_config: Optional[LoraConfig] = None
    type: Literal["create_model"] = "create_model"

View file

@ -0,0 +1,15 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing_extensions import Literal
from .._compat import PYDANTIC_V2, ConfigDict
from .._models import BaseModel
from .model_id import ModelID
__all__ = ["CreateModelResponse"]
class CreateModelResponse(BaseModel):
    """Response to a create-model request, carrying the new model's ID."""
    model_id: ModelID
    type: Literal["create_model"] = "create_model"

View file

@ -0,0 +1,14 @@
from .._models import BaseModel
__all__ = ["Cursor"]
class Cursor(BaseModel):
    """Pagination metadata for a list endpoint."""
    offset: int
    """The offset used for pagination"""
    limit: int
    """The maximum number of items requested"""
    total_count: int
    """The total number of items available"""

66
src/tinker/types/datum.py Normal file
View file

@ -0,0 +1,66 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Any, TYPE_CHECKING
from pydantic import model_validator
from .._models import StrictBase
from .loss_fn_inputs import LossFnInputs
from .model_input import ModelInput
from .tensor_data import TensorData
# torch is an optional dependency: probe for it at import time so Datum can
# convert torch.Tensor inputs when it is installed and still work without it.
try:
    import torch  # type: ignore[import-not-found]
    _HAVE_TORCH = True
except ImportError:
    _HAVE_TORCH = False
import numpy as np
# Re-import under TYPE_CHECKING so annotations resolve even when torch is absent.
if TYPE_CHECKING:
    import torch
__all__ = ["Datum"]
class Datum(StrictBase):
    """One training example: a model input plus its loss-function inputs.

    torch.Tensor, numpy.ndarray, and plain 1-D list values inside
    ``loss_fn_inputs`` are converted to ``TensorData`` by a ``mode="before"``
    validator, i.e. before pydantic field validation runs.
    """

    loss_fn_inputs: LossFnInputs
    """Dictionary mapping field names to tensor data"""

    model_input: ModelInput

    @model_validator(mode="before")
    @classmethod
    def convert_tensors(cls, data: Any) -> Any:
        """Convert torch.Tensor and numpy arrays to TensorData in loss_fn_inputs during construction."""
        # Only rewrite when given a raw dict payload; other input shapes
        # (e.g. an already-constructed model) pass through untouched.
        if isinstance(data, dict) and "loss_fn_inputs" in data:
            loss_fn_inputs = data["loss_fn_inputs"]
            if isinstance(loss_fn_inputs, dict):
                converted_inputs = {}
                for key, value in loss_fn_inputs.items():
                    converted_inputs[key] = cls._maybe_convert_array(key, value)
                # Shallow-copy so the caller's dict is not mutated in place.
                data = dict(data)  # Make a copy
                data["loss_fn_inputs"] = converted_inputs
        return data

    @classmethod
    def _maybe_convert_array(cls, key: str, value: Any) -> Any:
        """Convert torch.Tensor, numpy array, or 1-D list to TensorData if needed."""
        if _HAVE_TORCH and isinstance(value, torch.Tensor):
            return TensorData.from_torch(value)
        elif isinstance(value, np.ndarray):
            return TensorData.from_numpy(value)
        elif isinstance(value, list):
            # assume it's 1d and infer the dtype from the key
            # NOTE(review): this raises KeyError when a list value arrives under
            # a key not in _key_to_type — confirm whether unknown keys should
            # instead pass through or fail with a clearer error.
            return TensorData(data=value, dtype=_key_to_type[key], shape=[len(value)])
        else:
            # Anything else (already a TensorData, scalars, …) is left as-is.
            return value
# Default dtypes assumed when a plain Python list is supplied for one of the
# known loss-fn input fields (consulted by Datum._maybe_convert_array).
_key_to_type = dict(
    target_tokens="int64",
    weights="float32",
    advantages="float32",
    logprobs="float32",
)

View file

@ -0,0 +1,17 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing_extensions import Required, TypedDict
from .model_input_param import ModelInputParam
from .loss_fn_inputs_param import LossFnInputsParam
__all__ = ["DatumParam"]
class DatumParam(TypedDict, total=False):
    """Request-side counterpart of ``Datum``; both keys are required."""

    loss_fn_inputs: Required[LossFnInputsParam]
    """Dictionary mapping field names to tensor data"""

    model_input: Required[ModelInputParam]

View file

@ -0,0 +1,20 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Sequence
from typing_extensions import Literal
from .._models import StrictBase
__all__ = ["EncodedTextChunk"]
class EncodedTextChunk(StrictBase):
    """A chunk of text that has already been tokenized into token IDs."""

    tokens: Sequence[int]
    """Array of token IDs"""

    type: Literal["encoded_text"] = "encoded_text"  # chunk discriminator tag

    @property
    def length(self) -> int:
        """Number of tokens in this chunk."""
        return len(self.tokens)

View file

@ -0,0 +1,15 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import Iterable
from typing_extensions import Literal, Required, TypedDict
__all__ = ["EncodedTextChunkParam"]
class EncodedTextChunkParam(TypedDict, total=False):
    """Request-side counterpart of ``EncodedTextChunk``; both keys are required."""

    tokens: Required[Iterable[int]]
    """Array of token IDs"""

    type: Required[Literal["encoded_text"]]  # chunk discriminator tag

View file

@ -0,0 +1,18 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Dict, Optional
from .._models import BaseModel
__all__ = ["ErrorResponse"]
class ErrorResponse(BaseModel):
    """Standard error payload returned by the API."""

    error: str
    """Error code"""

    message: str
    """Human-readable error message"""

    details: Optional[Dict[str, object]] = None
    """Additional error details"""

View file

@ -0,0 +1,9 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing_extensions import Literal, TypeAlias
__all__ = ["EventType"]
# Closed set of telemetry event categories emitted by the SDK.
EventType: TypeAlias = Literal[
    "SESSION_START", "SESSION_END", "UNHANDLED_EXCEPTION", "GENERIC_EVENT"
]

View file

@ -0,0 +1,17 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import List
from .datum import Datum
from .._models import StrictBase
from .loss_fn_type import LossFnType
__all__ = ["ForwardBackwardInput"]
class ForwardBackwardInput(StrictBase):
    """Input batch for a combined forward/backward training pass."""

    data: List[Datum]
    """Array of input data for the forward/backward pass"""

    loss_fn: LossFnType
    """Fully qualified function path for the loss function"""

View file

@ -0,0 +1,19 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import Iterable
from typing_extensions import Required, TypedDict
from .datum_param import DatumParam
from .loss_fn_type import LossFnType
__all__ = ["ForwardBackwardInputParam"]
class ForwardBackwardInputParam(TypedDict, total=False):
    """Request-side counterpart of ``ForwardBackwardInput``; both keys are required."""

    data: Required[Iterable[DatumParam]]
    """Array of input data for the forward/backward pass"""

    loss_fn: Required[LossFnType]
    """Fully qualified function path for the loss function"""

View file

@ -0,0 +1,19 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Dict, List
from .._models import BaseModel
from .loss_fn_output import LossFnOutput
__all__ = ["ForwardBackwardOutput"]
class ForwardBackwardOutput(BaseModel):
    """Result of a forward/backward pass: per-example loss outputs plus metrics."""

    loss_fn_output_type: str
    """The type of the ForwardBackward output. Can be one of [...] TODO"""
    # NOTE(review): the valid values above were left as TODO upstream — fill in
    # once the OpenAPI spec enumerates them.

    loss_fn_outputs: List[LossFnOutput]
    """Dictionary mapping field names to tensor data"""
    # NOTE(review): described as "Dictionary" but typed as a List — the
    # docstring appears stale; confirm against the spec.

    metrics: Dict[str, float]
    """Training metrics as key-value pairs"""

View file

@ -0,0 +1,16 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing_extensions import Required, TypedDict
from .model_id import ModelID
from .request_id import RequestID
__all__ = ["FutureRetrieveParams"]
class FutureRetrieveParams(TypedDict, total=False):
    """Query parameters for retrieving the result of an in-flight request."""

    request_id: Required[RequestID]  # the request/future to poll

    model_id: ModelID  # optional: model the request was issued against

View file

@ -0,0 +1,26 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Union
from typing_extensions import TypeAlias
from .try_again_response import TryAgainResponse
from .optim_step_response import OptimStepResponse
from .create_model_response import CreateModelResponse
from .load_weights_response import LoadWeightsResponse
from .save_weights_response import SaveWeightsResponse
from .unload_model_response import UnloadModelResponse
from .forward_backward_output import ForwardBackwardOutput
from .save_weights_for_sampler_response import SaveWeightsForSamplerResponse
__all__ = ["FutureRetrieveResponse"]
# Union of every result a polled future can resolve to. TryAgainResponse is
# listed first — presumably the "not ready yet" marker returned while the
# request is still running. NOTE(review): member order can affect non-
# discriminated union validation in pydantic; confirm before reordering.
FutureRetrieveResponse: TypeAlias = Union[
    TryAgainResponse,
    ForwardBackwardOutput,
    OptimStepResponse,
    SaveWeightsResponse,
    LoadWeightsResponse,
    SaveWeightsForSamplerResponse,
    CreateModelResponse,
    UnloadModelResponse,
]

View file

@ -0,0 +1,30 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from datetime import datetime
from typing import Dict
from .._models import BaseModel
from .event_type import EventType
from .severity import Severity
__all__ = ["GenericEvent"]
class GenericEvent(BaseModel):
    """A single telemetry event record."""

    event: EventType
    """Telemetry event type"""

    event_id: str  # unique identifier of this event — presumably a UUID; TODO confirm

    event_name: str
    """Low-cardinality event name"""

    event_session_index: int  # ordinal position of the event within its session — TODO confirm

    severity: Severity
    """Log severity level"""

    timestamp: datetime  # when the event occurred; tz-awareness per server — TODO confirm

    # NOTE(review): the mutable `{}` default is safe on a pydantic model —
    # pydantic copies field defaults per instance rather than sharing them.
    event_data: Dict[str, object] = {}
    """Arbitrary structured JSON payload"""

View file

@ -0,0 +1,34 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from datetime import datetime
from typing import Dict, Union
from typing_extensions import Annotated, Required, TypedDict
from .._utils import PropertyInfo
from .event_type import EventType
from .severity import Severity
__all__ = ["GenericEventParam"]
class GenericEventParam(TypedDict, total=False):
    """Request-side telemetry event; every key except ``event_data`` is required."""

    event: Required[EventType]
    """Telemetry event type"""

    event_id: Required[str]  # unique identifier of this event

    event_name: Required[str]
    """Low-cardinality event name"""

    event_session_index: Required[int]  # ordinal position within the session

    severity: Required[Severity]
    """Log severity level"""

    # Accepts an ISO-8601 string or a datetime; serialized as iso8601.
    timestamp: Required[Annotated[Union[str, datetime], PropertyInfo(format="iso8601")]]

    event_data: Dict[str, object]
    """Arbitrary structured JSON payload"""

View file

@ -0,0 +1,16 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Optional
from typing_extensions import Literal
from .._compat import PYDANTIC_V2, ConfigDict
from .._models import StrictBase
from .model_id import ModelID
__all__ = ["GetInfoRequest"]
class GetInfoRequest(StrictBase):
    """Request for metadata about a model."""

    model_id: ModelID  # identifier of the model to query
    type: Optional[Literal["get_info"]] = None  # request discriminator tag (optional)

    if PYDANTIC_V2:
        # allow fields with a `model_` prefix — silences pydantic v2's
        # protected-namespace warning for `model_id`. This module already
        # imports PYDANTIC_V2 and ConfigDict but never applied them, which
        # suggests this guard was the intent; other models in the SDK with
        # `model_` fields carry it.
        model_config = ConfigDict(protected_namespaces=tuple())

View file

@ -0,0 +1,32 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Optional, Literal
from .._compat import PYDANTIC_V2, ConfigDict
from .._models import BaseModel
from .model_id import ModelID
__all__ = ["GetInfoResponse", "ModelData"]
class ModelData(BaseModel):
    """Architecture/name metadata for a model; all fields are optional."""

    arch: Optional[str] = None  # model architecture identifier — TODO confirm format
    model_name: Optional[str] = None

    if PYDANTIC_V2:
        # allow fields with a `model_` prefix — `model_name` otherwise triggers
        # pydantic v2's protected-namespace warning. Mirrors the guard already
        # present on the sibling GetInfoResponse class in this module, whose
        # PYDANTIC_V2 / ConfigDict imports this uses.
        model_config = ConfigDict(protected_namespaces=tuple())
class GetInfoResponse(BaseModel):
    """Response carrying metadata about a model."""

    type: Optional[Literal["get_info"]] = None  # response discriminator tag (optional)
    model_data: ModelData  # architecture/name details
    model_id: ModelID  # identifier of the queried model
    is_lora: Optional[bool] = None  # whether the model is a LoRA model — TODO confirm
    lora_rank: Optional[int] = None  # LoRA rank when applicable — TODO confirm
    model_name: Optional[str] = None

    if PYDANTIC_V2:
        # allow fields with a `model_` prefix
        model_config = ConfigDict(protected_namespaces=tuple())

Some files were not shown because too many files have changed in this diff Show more