mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
Publish Python SDK
Hello world! Signed-off-by: Daniel Xu <dxu@dxuuu.xyz>
This commit is contained in:
commit
829c151ba7
192 changed files with 25717 additions and 0 deletions
8
.devcontainer/Dockerfile
Normal file
8
.devcontainer/Dockerfile
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
ARG VARIANT="3.9"
|
||||
FROM mcr.microsoft.com/vscode/devcontainers/python:0-${VARIANT}
|
||||
|
||||
USER vscode
|
||||
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
|
||||
|
||||
RUN echo "[[ -d .venv ]] && source .venv/bin/activate || export PATH=\$PATH" >> /home/vscode/.bashrc
|
||||
43
.devcontainer/devcontainer.json
Normal file
43
.devcontainer/devcontainer.json
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
|
||||
// README at: https://github.com/devcontainers/templates/tree/main/src/debian
|
||||
{
|
||||
"name": "Debian",
|
||||
"build": {
|
||||
"dockerfile": "Dockerfile",
|
||||
"context": ".."
|
||||
},
|
||||
|
||||
"postStartCommand": "uv sync --all-extras",
|
||||
|
||||
"customizations": {
|
||||
"vscode": {
|
||||
"extensions": [
|
||||
"ms-python.python"
|
||||
],
|
||||
"settings": {
|
||||
"terminal.integrated.shell.linux": "/bin/bash",
|
||||
"python.pythonPath": ".venv/bin/python",
|
||||
"python.defaultInterpreterPath": ".venv/bin/python",
|
||||
"python.typeChecking": "basic",
|
||||
"terminal.integrated.env.linux": {
|
||||
"PATH": "${env:PATH}"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"features": {
|
||||
"ghcr.io/devcontainers/features/node:1": {}
|
||||
}
|
||||
|
||||
// Features to add to the dev container. More info: https://containers.dev/features.
|
||||
// "features": {},
|
||||
|
||||
// Use 'forwardPorts' to make a list of ports inside the container available locally.
|
||||
// "forwardPorts": [],
|
||||
|
||||
// Configure tool-specific properties.
|
||||
// "customizations": {},
|
||||
|
||||
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
|
||||
// "remoteUser": "root"
|
||||
}
|
||||
15
.gitignore
vendored
Normal file
15
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
.prism.log
|
||||
_dev
|
||||
|
||||
__pycache__
|
||||
.mypy_cache
|
||||
|
||||
dist
|
||||
|
||||
.venv
|
||||
.idea
|
||||
|
||||
.env
|
||||
.envrc
|
||||
codegen.log
|
||||
Brewfile.lock.json
|
||||
1
.python-version
Normal file
1
.python-version
Normal file
|
|
@ -0,0 +1 @@
|
|||
3.9.18
|
||||
40
.ruff.toml
Normal file
40
.ruff.toml
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
line-length = 100
|
||||
include = []
|
||||
exclude = []
|
||||
force-exclude = true
|
||||
|
||||
[lint]
|
||||
|
||||
# Same as monorepo but removed UP007
|
||||
select = [
|
||||
# https://docs.astral.sh/ruff/rules
|
||||
# pyflakes, pycodestyle, isort
|
||||
"B905", # zip-without-explicit-strict
|
||||
"F",
|
||||
"E",
|
||||
"W",
|
||||
"I001",
|
||||
|
||||
"PIE804", # unnecessary-dict-kwargs
|
||||
"PIE800", # unnecessary-spread
|
||||
"PIE796", # non-unique-enums
|
||||
"PIE794", # duplicate-class-field-definition
|
||||
"PIE807", # reimplemented-container-builtin
|
||||
"PIE810", #multiple-starts-ends-with
|
||||
|
||||
"FLY002",
|
||||
"COM818",
|
||||
"SIM",
|
||||
"Q000",
|
||||
|
||||
#"UP007", # Use X | Y for type annotations
|
||||
# Going to leave these 2 commented out for now as they play interestingly with chz.
|
||||
# We could consider using them though!
|
||||
# "TC001", # typing-only-first-party-import
|
||||
# "TC002", # typing-only-third-party-import
|
||||
"TC004", # runtime-import-in-type-checking-block
|
||||
"TC005", # empty-type-checking-block
|
||||
]
|
||||
ignore = ["E501", "SIM108", "SIM117"]
|
||||
unfixable = ["B905"]
|
||||
extend-safe-fixes = ["TC004", "TC005"]
|
||||
4
.stats.yml
Normal file
4
.stats.yml
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
configured_endpoints: 15
|
||||
openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/thinking-machines%2Ftinker-51afac49fce6fe0323489c3e809c5e2e50ce7384828362dcb0b7c2c7c9d77027.yml
|
||||
openapi_spec_hash: ed32399a1f1754a8e66918aefbe9fd07
|
||||
config_hash: 2ea282a8a1396267cb7c8a5e28467eb8
|
||||
4
.sync_state
Normal file
4
.sync_state
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
{
|
||||
"last_synced_sha": "a3e70a3c6414c4a3d5bae16cc2a612ab65813dd4",
|
||||
"last_sync_time": "2025-10-01T17:30:03.308998"
|
||||
}
|
||||
3
.vscode/settings.json
vendored
Normal file
3
.vscode/settings.json
vendored
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"python.analysis.importFormat": "relative",
|
||||
}
|
||||
1
Brewfile
Normal file
1
Brewfile
Normal file
|
|
@ -0,0 +1 @@
|
|||
brew "uv"
|
||||
128
CONTRIBUTING.md
Normal file
128
CONTRIBUTING.md
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
## Setting up the environment
|
||||
|
||||
### With `uv`
|
||||
|
||||
We use [uv](https://docs.astral.sh/uv/) to manage dependencies because it will automatically provision a Python environment with the expected Python version. To set it up, run:
|
||||
|
||||
```sh
|
||||
$ ./scripts/bootstrap
|
||||
```
|
||||
|
||||
Or [install uv manually](https://docs.astral.sh/uv/getting-started/installation/) and run:
|
||||
|
||||
```sh
|
||||
$ uv sync --all-extras
|
||||
```
|
||||
|
||||
You can then run scripts using `uv run python script.py` or by manually activating the virtual environment:
|
||||
|
||||
```sh
|
||||
# manually activate - https://docs.python.org/3/library/venv.html#how-venvs-work
|
||||
$ source .venv/bin/activate
|
||||
|
||||
# now you can omit the `uv run` prefix
|
||||
$ python script.py
|
||||
```
|
||||
|
||||
### Without `uv`
|
||||
|
||||
Alternatively if you don't want to install `uv`, you can stick with the standard `pip` setup by ensuring you have the Python version specified in `.python-version`, create a virtual environment however you desire and then install dependencies using this command:
|
||||
|
||||
```sh
|
||||
$ pip install -r requirements-dev.lock
|
||||
```
|
||||
|
||||
## Modifying/Adding code
|
||||
|
||||
Most of the SDK is generated code. Modifications to code will be persisted between generations, but may
|
||||
result in merge conflicts between manual patches and changes from the generator. The generator will never
|
||||
modify the contents of the `src/tinker/lib/` and `examples/` directories.
|
||||
|
||||
## Adding and running examples
|
||||
|
||||
All files in the `examples/` directory are not modified by the generator and can be freely edited or added to.
|
||||
|
||||
```py
|
||||
# add an example to examples/<your-example>.py
|
||||
|
||||
#!/usr/bin/env -S uv run python
|
||||
…
|
||||
```
|
||||
|
||||
```sh
|
||||
$ chmod +x examples/<your-example>.py
|
||||
# run the example against your api
|
||||
$ ./examples/<your-example>.py
|
||||
```
|
||||
|
||||
## Using the repository from source
|
||||
|
||||
If you’d like to use the repository from source, you can either install from git or link to a cloned repository:
|
||||
|
||||
To install via git:
|
||||
|
||||
```sh
|
||||
$ pip install git+ssh://git@github.com/stainless-sdks/tinker-python.git
|
||||
```
|
||||
|
||||
Alternatively, you can build from source and install the wheel file:
|
||||
|
||||
Building this package will create two files in the `dist/` directory, a `.tar.gz` containing the source files and a `.whl` that can be used to install the package efficiently.
|
||||
|
||||
To create a distributable version of the library, all you have to do is run this command:
|
||||
|
||||
```sh
|
||||
$ uv build
|
||||
# or
|
||||
$ python -m build
|
||||
```
|
||||
|
||||
Then to install:
|
||||
|
||||
```sh
|
||||
$ pip install ./path-to-wheel-file.whl
|
||||
```
|
||||
|
||||
## Running tests
|
||||
|
||||
Most tests require you to [set up a mock server](https://github.com/stoplightio/prism) against the OpenAPI spec to run the tests.
|
||||
|
||||
```sh
|
||||
# you will need npm installed
|
||||
$ npx prism mock path/to/your/openapi.yml
|
||||
```
|
||||
|
||||
```sh
|
||||
$ ./scripts/test
|
||||
```
|
||||
|
||||
## Linting and formatting
|
||||
|
||||
This repository uses [ruff](https://github.com/astral-sh/ruff) and
|
||||
[black](https://github.com/psf/black) to format the code in the repository.
|
||||
|
||||
To lint:
|
||||
|
||||
```sh
|
||||
$ ./scripts/lint
|
||||
```
|
||||
|
||||
To format and fix all ruff issues automatically:
|
||||
|
||||
```sh
|
||||
$ ./scripts/format
|
||||
```
|
||||
|
||||
## Publishing and releases
|
||||
|
||||
Changes made to this repository via the automated release PR pipeline should publish to PyPI automatically. If
|
||||
the changes aren't made through the automated pipeline, you may want to make releases manually.
|
||||
|
||||
### Publish with a GitHub workflow
|
||||
|
||||
You can release to package managers by using [the `Publish PyPI` GitHub action](https://www.github.com/stainless-sdks/tinker-python/actions/workflows/publish-pypi.yml). This requires a setup organization or repository secret to be set up.
|
||||
|
||||
### Publish manually
|
||||
|
||||
If you need to manually release a package, you can run the `bin/publish-pypi` script with a `PYPI_TOKEN` set on
|
||||
the environment.
|
||||
201
LICENSE
Normal file
201
LICENSE
Normal file
|
|
@ -0,0 +1,201 @@
|
|||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2025 Tinker
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
7
README.md
Normal file
7
README.md
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
<h1 align="center">Tinker Python SDK</h1>
|
||||
<div align="center">
|
||||
<img src="docs/images/logo.svg" width="60%" />
|
||||
|
||||
Documentation:
|
||||
<a href="http://tinker-docs.thinkingmachines.ai/">tinker-docs.thinkingmachines.ai</a>
|
||||
</div>
|
||||
23
SECURITY.md
Normal file
23
SECURITY.md
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
# Security Policy
|
||||
|
||||
## Reporting Security Issues
|
||||
|
||||
This SDK is generated by [Stainless Software Inc](http://stainless.com). Stainless takes security seriously, and encourages you to report any security vulnerability promptly so that appropriate action can be taken.
|
||||
|
||||
To report a security issue, please contact the Stainless team at security@stainless.com.
|
||||
|
||||
## Responsible Disclosure
|
||||
|
||||
We appreciate the efforts of security researchers and individuals who help us maintain the security of
|
||||
SDKs we generate. If you believe you have found a security vulnerability, please adhere to responsible
|
||||
disclosure practices by allowing us a reasonable amount of time to investigate and address the issue
|
||||
before making any information public.
|
||||
|
||||
## Reporting Non-SDK Related Security Issues
|
||||
|
||||
If you encounter security issues that are not directly related to SDKs but pertain to the services
|
||||
or products provided by Tinker, please follow the respective company's security reporting guidelines.
|
||||
|
||||
---
|
||||
|
||||
Thank you for helping us keep the SDKs and systems they interact with secure.
|
||||
7
bin/publish-pypi
Normal file
7
bin/publish-pypi
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
set -eux
|
||||
rm -rf dist
|
||||
mkdir -p dist
|
||||
uv build
|
||||
uv publish --token=$PYPI_TOKEN
|
||||
BIN
docs/images/logo.png
Normal file
BIN
docs/images/logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 319 KiB |
967
docs/images/logo.svg
Normal file
967
docs/images/logo.svg
Normal file
|
|
@ -0,0 +1,967 @@
|
|||
<svg width="820" height="512" viewBox="0 0 820 512" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect width="820" height="512" fill="white"/>
|
||||
<rect x="618.852" y="46" width="2" height="48" transform="rotate(-45 618.852 46)" fill="#FFCE52"/>
|
||||
<rect x="679.797" y="44.5859" width="2" height="48" transform="rotate(45 679.797 44.5859)" fill="#FFCE52"/>
|
||||
<rect x="233" y="75" width="120" height="2" fill="#73BB83"/>
|
||||
<rect x="463" y="75" width="120" height="2" fill="#73BB83"/>
|
||||
<rect x="233" y="195" width="120" height="2" fill="#73BB83"/>
|
||||
<rect x="463" y="195" width="120" height="2" fill="#73BB83"/>
|
||||
<rect x="110" y="75" width="120" height="2" fill="#FF9D52"/>
|
||||
<rect x="140" y="45" width="60" height="2" fill="#FF9D52"/>
|
||||
<rect x="380" y="45" width="60" height="2" fill="#FF9D52"/>
|
||||
<rect x="620" y="45" width="60" height="2" fill="#FF9D52"/>
|
||||
<rect x="110" y="135" width="120" height="2" fill="#FF9D52"/>
|
||||
<rect x="110" y="195" width="120" height="2" fill="#FF9D52"/>
|
||||
<rect x="169" y="74" width="2" height="120" fill="#4196D8"/>
|
||||
<rect x="230.805" y="134" width="2" height="90" transform="rotate(45 230.805 134)" fill="#D34F4F"/>
|
||||
<rect x="174.234" y="80.4141" width="2" height="80" transform="rotate(-45 174.234 80.4141)" fill="#D34F4F"/>
|
||||
<rect x="235.234" y="82.4141" width="2" height="160" transform="rotate(-45 235.234 82.4141)" fill="#D34F4F"/>
|
||||
<rect x="235.234" y="202.414" width="2" height="160" transform="rotate(-45 235.234 202.414)" fill="#D34F4F"/>
|
||||
<rect x="115.234" y="202.414" width="2" height="160" transform="rotate(-45 115.234 202.414)" fill="#4196D8"/>
|
||||
<rect x="355.234" y="202.414" width="2" height="160" transform="rotate(-45 355.234 202.414)" fill="#4196D8"/>
|
||||
<rect x="595.234" y="202.414" width="2" height="160" transform="rotate(-45 595.234 202.414)" fill="#4196D8"/>
|
||||
<rect x="475.234" y="202.414" width="2" height="160" transform="rotate(-45 475.234 202.414)" fill="#D34F4F"/>
|
||||
<rect x="473.234" y="80.4141" width="2" height="170" transform="rotate(-45 473.234 80.4141)" fill="#D34F4F"/>
|
||||
<rect x="109.234" y="137" width="2" height="80" transform="rotate(-45 109.234 137)" fill="#D34F4F"/>
|
||||
<rect x="167.805" y="75.5859" width="2" height="80" transform="rotate(45 167.805 75.5859)" fill="#D34F4F"/>
|
||||
<rect x="139.797" y="44.5859" width="2" height="48" transform="rotate(45 139.797 44.5859)" fill="#FFCE52"/>
|
||||
<rect x="138.852" y="46" width="2" height="48" transform="rotate(-45 138.852 46)" fill="#FFCE52"/>
|
||||
<rect x="378.852" y="46" width="2" height="48" transform="rotate(-45 378.852 46)" fill="#FFCE52"/>
|
||||
<rect x="198.844" y="46" width="2" height="48" transform="rotate(-45 198.844 46)" fill="#FFCE52"/>
|
||||
<rect x="438.844" y="46" width="2" height="48" transform="rotate(-45 438.844 46)" fill="#FFCE52"/>
|
||||
<rect x="678.844" y="46" width="2" height="48" transform="rotate(-45 678.844 46)" fill="#FFCE52"/>
|
||||
<rect x="199.797" y="44.5859" width="2" height="48" transform="rotate(45 199.797 44.5859)" fill="#FFCE52"/>
|
||||
<rect x="379.797" y="44.5859" width="2" height="48" transform="rotate(45 379.797 44.5859)" fill="#FFCE52"/>
|
||||
<rect x="439.797" y="44.5859" width="2" height="48" transform="rotate(45 439.797 44.5859)" fill="#FFCE52"/>
|
||||
<rect x="619.797" y="44.5859" width="2" height="48" transform="rotate(45 619.797 44.5859)" fill="#FFCE52"/>
|
||||
<rect x="348.805" y="75.5859" width="2" height="170" transform="rotate(45 348.805 75.5859)" fill="#D34F4F"/>
|
||||
<rect x="348.805" y="195.586" width="2" height="170" transform="rotate(45 348.805 195.586)" fill="#D34F4F"/>
|
||||
<rect x="588.805" y="195.586" width="2" height="170" transform="rotate(45 588.805 195.586)" fill="#D34F4F"/>
|
||||
<rect x="588.805" y="75.5859" width="2" height="170" transform="rotate(45 588.805 75.5859)" fill="#D34F4F"/>
|
||||
<rect x="109" y="74" width="2" height="120" fill="#FF9D52"/>
|
||||
<rect x="229" y="74" width="2" height="120" fill="#FF9D52"/>
|
||||
<circle cx="110" cy="76" r="7.5" fill="url(#paint0_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="110" cy="76" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="140" cy="76" r="7.5" fill="url(#paint1_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="140" cy="76" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="170" cy="76" r="7.5" fill="url(#paint2_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="170" cy="76" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="200" cy="76" r="7.5" fill="url(#paint3_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="200" cy="76" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="230" cy="76" r="7.5" fill="url(#paint4_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="230" cy="76" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="140" cy="136" r="7.5" fill="url(#paint5_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="140" cy="136" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="200" cy="136" r="7.5" fill="url(#paint6_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="200" cy="136" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="170" cy="136" r="7.5" fill="url(#paint7_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="170" cy="136" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="230" cy="106" r="7.5" fill="url(#paint8_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="230" cy="106" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="110" cy="106" r="7.5" fill="url(#paint9_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="110" cy="106" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="230" cy="136" r="7.5" fill="url(#paint10_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="230" cy="136" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="110" cy="136" r="7.5" fill="url(#paint11_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="110" cy="136" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="230" cy="166" r="7.5" fill="url(#paint12_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="230" cy="166" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="110" cy="166" r="7.5" fill="url(#paint13_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="110" cy="166" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="290" cy="136" r="7.5" fill="url(#paint14_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="290" cy="136" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="290" cy="256" r="7.5" fill="url(#paint15_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="290" cy="256" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="530" cy="256" r="7.5" fill="url(#paint16_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="530" cy="256" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<rect x="350" y="75" width="120" height="2" fill="#FF9D52"/>
|
||||
<rect x="350" y="135" width="120" height="2" fill="#FF9D52"/>
|
||||
<rect x="350" y="195" width="120" height="2" fill="#FF9D52"/>
|
||||
<rect x="409" y="74" width="2" height="120" fill="#4196D8"/>
|
||||
<rect x="470.805" y="134" width="2" height="80" transform="rotate(45 470.805 134)" fill="#D34F4F"/>
|
||||
<rect x="414.234" y="81.4141" width="2" height="80" transform="rotate(-45 414.234 81.4141)" fill="#D34F4F"/>
|
||||
<rect x="349.234" y="137" width="2" height="80" transform="rotate(-45 349.234 137)" fill="#D34F4F"/>
|
||||
<rect x="408.789" y="75.5859" width="2" height="80" transform="rotate(45 408.789 75.5859)" fill="#D34F4F"/>
|
||||
<rect x="349" y="74" width="2" height="120" fill="#FF9D52"/>
|
||||
<rect x="349" y="204" width="2" height="120" fill="#73BB83"/>
|
||||
<rect x="229" y="204" width="2" height="120" fill="#73BB83"/>
|
||||
<rect x="109" y="204" width="2" height="120" fill="#73BB83"/>
|
||||
<rect x="469" y="204" width="2" height="120" fill="#73BB83"/>
|
||||
<rect x="709" y="204" width="2" height="120" fill="#73BB83"/>
|
||||
<rect x="589" y="204" width="2" height="120" fill="#73BB83"/>
|
||||
<rect x="469" y="74" width="2" height="120" fill="#FF9D52"/>
|
||||
<circle cx="350" cy="76" r="7.5" fill="url(#paint17_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="350" cy="76" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="380" cy="76" r="7.5" fill="url(#paint18_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="380" cy="76" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="410" cy="76" r="7.5" fill="url(#paint19_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="410" cy="76" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="440" cy="76" r="7.5" fill="url(#paint20_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="440" cy="76" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="470" cy="76" r="7.5" fill="url(#paint21_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="470" cy="76" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="380" cy="136" r="7.5" fill="url(#paint22_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="380" cy="136" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="440" cy="136" r="7.5" fill="url(#paint23_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="440" cy="136" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="410" cy="136" r="7.5" fill="url(#paint24_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="410" cy="136" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="470" cy="106" r="7.5" fill="url(#paint25_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="470" cy="106" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="350" cy="106" r="7.5" fill="url(#paint26_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="350" cy="106" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="470" cy="136" r="7.5" fill="url(#paint27_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="470" cy="136" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="350" cy="136" r="7.5" fill="url(#paint28_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="350" cy="136" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="470" cy="166" r="7.5" fill="url(#paint29_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="470" cy="166" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="350" cy="166" r="7.5" fill="url(#paint30_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="350" cy="166" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="350" cy="196" r="7.5" fill="url(#paint31_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="350" cy="196" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="380" cy="196" r="7.5" fill="url(#paint32_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="380" cy="196" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="410" cy="196" r="7.5" fill="url(#paint33_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="410" cy="196" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="440" cy="196" r="7.5" fill="url(#paint34_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="440" cy="196" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="470" cy="196" r="7.5" fill="url(#paint35_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="470" cy="196" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="530" cy="136" r="7.5" fill="url(#paint36_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="530" cy="136" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<rect x="590" y="75" width="120" height="2" fill="#FF9D52"/>
|
||||
<rect x="590" y="135" width="120" height="2" fill="#FF9D52"/>
|
||||
<rect x="590" y="195" width="120" height="2" fill="#FF9D52"/>
|
||||
<rect x="649" y="74" width="2" height="120" fill="#4196D8"/>
|
||||
<rect x="710.805" y="134" width="2" height="80" transform="rotate(45 710.805 134)" fill="#D34F4F"/>
|
||||
<rect x="654.234" y="81.4141" width="2" height="80" transform="rotate(-45 654.234 81.4141)" fill="#D34F4F"/>
|
||||
<rect x="589.234" y="137" width="2" height="80" transform="rotate(-45 589.234 137)" fill="#D34F4F"/>
|
||||
<rect x="646.086" y="78.5859" width="2" height="80" transform="rotate(45 646.086 78.5859)" fill="#D34F4F"/>
|
||||
<rect x="589" y="74" width="2" height="120" fill="#FF9D52"/>
|
||||
<rect x="709" y="74" width="2" height="120" fill="#FF9D52"/>
|
||||
<circle cx="590" cy="76" r="7.5" fill="url(#paint37_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="590" cy="76" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="620" cy="76" r="7.5" fill="url(#paint38_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="620" cy="76" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="650" cy="76" r="7.5" fill="url(#paint39_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="650" cy="76" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="680" cy="76" r="7.5" fill="url(#paint40_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="680" cy="76" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="710" cy="76" r="7.5" fill="url(#paint41_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="710" cy="76" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="620" cy="136" r="7.5" fill="url(#paint42_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="620" cy="136" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="680" cy="136" r="7.5" fill="url(#paint43_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="680" cy="136" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="650" cy="136" r="7.5" fill="url(#paint44_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="650" cy="136" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="710" cy="106" r="7.5" fill="url(#paint45_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="710" cy="106" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="590" cy="106" r="7.5" fill="url(#paint46_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="590" cy="106" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="710" cy="136" r="7.5" fill="url(#paint47_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="710" cy="136" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="590" cy="136" r="7.5" fill="url(#paint48_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="590" cy="136" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="710" cy="166" r="7.5" fill="url(#paint49_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="710" cy="166" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="590" cy="166" r="7.5" fill="url(#paint50_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="590" cy="166" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<rect x="59" y="134" width="2" height="250" fill="#FF9D52"/>
|
||||
<rect x="113.195" y="199.129" width="2" height="80" transform="rotate(140.469 113.195 199.129)" fill="#FFCE52"/>
|
||||
<rect x="59.5391" y="137.977" width="2" height="80" transform="rotate(-140.47 59.5391 137.977)" fill="#FFCE52"/>
|
||||
<rect x="60.5391" y="256.977" width="2" height="80" transform="rotate(-140.47 60.5391 256.977)" fill="#FFCE52"/>
|
||||
<rect x="61.5391" y="375.977" width="2" height="80" transform="rotate(-140.47 61.5391 375.977)" fill="#FFCE52"/>
|
||||
<rect x="113.195" y="439.129" width="2" height="80" transform="rotate(140.469 113.195 439.129)" fill="#FFCE52"/>
|
||||
<rect x="113.195" y="319.129" width="2" height="80" transform="rotate(140.469 113.195 319.129)" fill="#FFCE52"/>
|
||||
<circle cx="60" cy="136" r="7.5" fill="url(#paint51_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="60" cy="136" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="60" cy="376" r="7.5" fill="url(#paint52_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="60" cy="376" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="60" cy="256" r="7.5" fill="url(#paint53_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="60" cy="256" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<rect x="761.461" y="367.383" width="2" height="230" transform="rotate(-180 761.461 367.383)" fill="#FF9D52"/>
|
||||
<rect x="758.922" y="375.402" width="2" height="80" transform="rotate(39.53 758.922 375.402)" fill="#FFCE52"/>
|
||||
<rect x="757.922" y="256.402" width="2" height="80" transform="rotate(39.53 757.922 256.402)" fill="#FFCE52"/>
|
||||
<rect x="756.922" y="137.406" width="2" height="80" transform="rotate(39.53 756.922 137.406)" fill="#FFCE52"/>
|
||||
<rect x="707.266" y="192.25" width="2" height="80" transform="rotate(-39.5305 707.266 192.25)" fill="#FFCE52"/>
|
||||
<circle cx="110" cy="196" r="7.5" fill="url(#paint54_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="110" cy="196" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="140" cy="196" r="7.5" fill="url(#paint55_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="140" cy="196" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="170" cy="196" r="7.5" fill="url(#paint56_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="170" cy="196" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="200" cy="196" r="7.5" fill="url(#paint57_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="200" cy="196" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="230" cy="196" r="7.5" fill="url(#paint58_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="230" cy="196" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<rect x="707.266" y="312.25" width="2" height="80" transform="rotate(-39.5305 707.266 312.25)" fill="#FFCE52"/>
|
||||
<rect x="707.266" y="72.2539" width="2" height="80" transform="rotate(-39.5305 707.266 72.2539)" fill="#FFCE52"/>
|
||||
<rect x="707.266" y="72.2539" width="2" height="80" transform="rotate(-39.5305 707.266 72.2539)" fill="#FFCE52"/>
|
||||
<circle cx="760.461" cy="375.379" r="7.5" transform="rotate(-180 760.461 375.379)" fill="url(#paint59_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="760.461" cy="375.379" r="1" transform="rotate(-180 760.461 375.379)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="760.461" cy="255.379" r="7.5" transform="rotate(-180 760.461 255.379)" fill="url(#paint60_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="760.461" cy="255.379" r="1" transform="rotate(-180 760.461 255.379)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="760.461" cy="135.383" r="7.5" transform="rotate(-180 760.461 135.383)" fill="url(#paint61_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="760.461" cy="135.383" r="1" transform="rotate(-180 760.461 135.383)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="140" cy="46" r="7.5" fill="url(#paint62_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="140" cy="46" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="380" cy="46" r="7.5" fill="url(#paint63_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="380" cy="46" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="620" cy="46" r="7.5" fill="url(#paint64_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="620" cy="46" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="200" cy="46" r="7.5" fill="url(#paint65_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="200" cy="46" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="440" cy="46" r="7.5" fill="url(#paint66_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="440" cy="46" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="680" cy="46" r="7.5" fill="url(#paint67_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="680" cy="46" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<rect x="201.148" y="466" width="2" height="48" transform="rotate(135 201.148 466)" fill="#FFCE52"/>
|
||||
<rect x="140.203" y="467.414" width="2" height="48" transform="rotate(-135 140.203 467.414)" fill="#FFCE52"/>
|
||||
<rect x="587" y="437" width="120" height="2" transform="rotate(-180 587 437)" fill="#73BB83"/>
|
||||
<rect x="357" y="437" width="120" height="2" transform="rotate(-180 357 437)" fill="#73BB83"/>
|
||||
<rect x="587" y="317" width="120" height="2" transform="rotate(-180 587 317)" fill="#73BB83"/>
|
||||
<rect x="357" y="317" width="120" height="2" transform="rotate(-180 357 317)" fill="#73BB83"/>
|
||||
<rect x="710" y="437" width="120" height="2" transform="rotate(-180 710 437)" fill="#FF9D52"/>
|
||||
<rect x="680" y="467" width="60" height="2" transform="rotate(-180 680 467)" fill="#FF9D52"/>
|
||||
<rect x="440" y="467" width="60" height="2" transform="rotate(-180 440 467)" fill="#FF9D52"/>
|
||||
<rect x="200" y="467" width="60" height="2" transform="rotate(-180 200 467)" fill="#FF9D52"/>
|
||||
<rect x="710" y="377" width="120" height="2" transform="rotate(-180 710 377)" fill="#FF9D52"/>
|
||||
<rect x="710" y="317" width="120" height="2" transform="rotate(-180 710 317)" fill="#FF9D52"/>
|
||||
<rect x="651" y="438" width="2" height="120" transform="rotate(-180 651 438)" fill="#4196D8"/>
|
||||
<rect x="589.195" y="378" width="2" height="90" transform="rotate(-135 589.195 378)" fill="#D34F4F"/>
|
||||
<rect x="645.766" y="431.586" width="2" height="80" transform="rotate(135 645.766 431.586)" fill="#D34F4F"/>
|
||||
<rect x="584.766" y="429.586" width="2" height="160" transform="rotate(135 584.766 429.586)" fill="#D34F4F"/>
|
||||
<rect x="346.766" y="431.586" width="2" height="170" transform="rotate(135 346.766 431.586)" fill="#D34F4F"/>
|
||||
<rect x="710.766" y="375" width="2" height="80" transform="rotate(135 710.766 375)" fill="#D34F4F"/>
|
||||
<rect x="652.195" y="436.414" width="2" height="80" transform="rotate(-135 652.195 436.414)" fill="#D34F4F"/>
|
||||
<rect x="680.203" y="467.414" width="2" height="48" transform="rotate(-135 680.203 467.414)" fill="#FFCE52"/>
|
||||
<rect x="681.148" y="466" width="2" height="48" transform="rotate(135 681.148 466)" fill="#FFCE52"/>
|
||||
<rect x="441.148" y="466" width="2" height="48" transform="rotate(135 441.148 466)" fill="#FFCE52"/>
|
||||
<rect x="621.156" y="466" width="2" height="48" transform="rotate(135 621.156 466)" fill="#FFCE52"/>
|
||||
<rect x="381.156" y="466" width="2" height="48" transform="rotate(135 381.156 466)" fill="#FFCE52"/>
|
||||
<rect x="141.156" y="466" width="2" height="48" transform="rotate(135 141.156 466)" fill="#FFCE52"/>
|
||||
<rect x="620.203" y="467.414" width="2" height="48" transform="rotate(-135 620.203 467.414)" fill="#FFCE52"/>
|
||||
<rect x="440.203" y="467.414" width="2" height="48" transform="rotate(-135 440.203 467.414)" fill="#FFCE52"/>
|
||||
<rect x="380.203" y="467.414" width="2" height="48" transform="rotate(-135 380.203 467.414)" fill="#FFCE52"/>
|
||||
<rect x="200.203" y="467.414" width="2" height="48" transform="rotate(-135 200.203 467.414)" fill="#FFCE52"/>
|
||||
<rect x="471.195" y="436.414" width="2" height="170" transform="rotate(-135 471.195 436.414)" fill="#D34F4F"/>
|
||||
<rect x="231.195" y="436.414" width="2" height="170" transform="rotate(-135 231.195 436.414)" fill="#D34F4F"/>
|
||||
<rect x="711" y="438" width="2" height="120" transform="rotate(-180 711 438)" fill="#FF9D52"/>
|
||||
<rect x="591" y="438" width="2" height="120" transform="rotate(-180 591 438)" fill="#FF9D52"/>
|
||||
<circle cx="710" cy="436" r="7.5" transform="rotate(-180 710 436)" fill="url(#paint68_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="710" cy="436" r="1" transform="rotate(-180 710 436)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="680" cy="436" r="7.5" transform="rotate(-180 680 436)" fill="url(#paint69_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="680" cy="436" r="1" transform="rotate(-180 680 436)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="650" cy="436" r="7.5" transform="rotate(-180 650 436)" fill="url(#paint70_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="650" cy="436" r="1" transform="rotate(-180 650 436)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="620" cy="436" r="7.5" transform="rotate(-180 620 436)" fill="url(#paint71_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="620" cy="436" r="1" transform="rotate(-180 620 436)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="590" cy="436" r="7.5" transform="rotate(-180 590 436)" fill="url(#paint72_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="590" cy="436" r="1" transform="rotate(-180 590 436)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="680" cy="376" r="7.5" transform="rotate(-180 680 376)" fill="url(#paint73_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="680" cy="376" r="1" transform="rotate(-180 680 376)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="620" cy="376" r="7.5" transform="rotate(-180 620 376)" fill="url(#paint74_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="620" cy="376" r="1" transform="rotate(-180 620 376)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="650" cy="376" r="7.5" transform="rotate(-180 650 376)" fill="url(#paint75_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="650" cy="376" r="1" transform="rotate(-180 650 376)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="590" cy="406" r="7.5" transform="rotate(-180 590 406)" fill="url(#paint76_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="590" cy="406" r="1" transform="rotate(-180 590 406)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="710" cy="406" r="7.5" transform="rotate(-180 710 406)" fill="url(#paint77_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="710" cy="406" r="1" transform="rotate(-180 710 406)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="590" cy="376" r="7.5" transform="rotate(-180 590 376)" fill="url(#paint78_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="590" cy="376" r="1" transform="rotate(-180 590 376)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="710" cy="376" r="7.5" transform="rotate(-180 710 376)" fill="url(#paint79_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="710" cy="376" r="1" transform="rotate(-180 710 376)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="590" cy="346" r="7.5" transform="rotate(-180 590 346)" fill="url(#paint80_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="590" cy="346" r="1" transform="rotate(-180 590 346)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="710" cy="346" r="7.5" transform="rotate(-180 710 346)" fill="url(#paint81_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="710" cy="346" r="1" transform="rotate(-180 710 346)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="710" cy="316" r="7.5" transform="rotate(-180 710 316)" fill="url(#paint82_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="710" cy="316" r="1" transform="rotate(-180 710 316)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="680" cy="316" r="7.5" transform="rotate(-180 680 316)" fill="url(#paint83_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="680" cy="316" r="1" transform="rotate(-180 680 316)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="650" cy="316" r="7.5" transform="rotate(-180 650 316)" fill="url(#paint84_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="650" cy="316" r="1" transform="rotate(-180 650 316)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="620" cy="316" r="7.5" transform="rotate(-180 620 316)" fill="url(#paint85_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="620" cy="316" r="1" transform="rotate(-180 620 316)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="590" cy="316" r="7.5" transform="rotate(-180 590 316)" fill="url(#paint86_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="590" cy="316" r="1" transform="rotate(-180 590 316)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="530" cy="376" r="7.5" transform="rotate(-180 530 376)" fill="url(#paint87_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="530" cy="376" r="1" transform="rotate(-180 530 376)" fill="black" fill-opacity="0.4"/>
|
||||
<rect x="470" y="437" width="120" height="2" transform="rotate(-180 470 437)" fill="#FF9D52"/>
|
||||
<rect x="470" y="377" width="120" height="2" transform="rotate(-180 470 377)" fill="#FF9D52"/>
|
||||
<rect x="470" y="317" width="120" height="2" transform="rotate(-180 470 317)" fill="#FF9D52"/>
|
||||
<rect x="411" y="438" width="2" height="120" transform="rotate(-180 411 438)" fill="#4196D8"/>
|
||||
<rect x="349.195" y="378" width="2" height="80" transform="rotate(-135 349.195 378)" fill="#D34F4F"/>
|
||||
<rect x="405.766" y="430.586" width="2" height="80" transform="rotate(135 405.766 430.586)" fill="#D34F4F"/>
|
||||
<rect x="470.766" y="375" width="2" height="80" transform="rotate(135 470.766 375)" fill="#D34F4F"/>
|
||||
<rect x="411.211" y="436.414" width="2" height="80" transform="rotate(-135 411.211 436.414)" fill="#D34F4F"/>
|
||||
<rect x="471" y="438" width="2" height="120" transform="rotate(-180 471 438)" fill="#FF9D52"/>
|
||||
<rect x="351" y="438" width="2" height="120" transform="rotate(-180 351 438)" fill="#FF9D52"/>
|
||||
<circle cx="470" cy="436" r="7.5" transform="rotate(-180 470 436)" fill="url(#paint88_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="470" cy="436" r="1" transform="rotate(-180 470 436)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="440" cy="436" r="7.5" transform="rotate(-180 440 436)" fill="url(#paint89_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="440" cy="436" r="1" transform="rotate(-180 440 436)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="410" cy="436" r="7.5" transform="rotate(-180 410 436)" fill="url(#paint90_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="410" cy="436" r="1" transform="rotate(-180 410 436)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="380" cy="436" r="7.5" transform="rotate(-180 380 436)" fill="url(#paint91_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="380" cy="436" r="1" transform="rotate(-180 380 436)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="350" cy="436" r="7.5" transform="rotate(-180 350 436)" fill="url(#paint92_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="350" cy="436" r="1" transform="rotate(-180 350 436)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="440" cy="376" r="7.5" transform="rotate(-180 440 376)" fill="url(#paint93_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="440" cy="376" r="1" transform="rotate(-180 440 376)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="380" cy="376" r="7.5" transform="rotate(-180 380 376)" fill="url(#paint94_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="380" cy="376" r="1" transform="rotate(-180 380 376)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="410" cy="376" r="7.5" transform="rotate(-180 410 376)" fill="url(#paint95_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="410" cy="376" r="1" transform="rotate(-180 410 376)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="350" cy="406" r="7.5" transform="rotate(-180 350 406)" fill="url(#paint96_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="350" cy="406" r="1" transform="rotate(-180 350 406)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="470" cy="406" r="7.5" transform="rotate(-180 470 406)" fill="url(#paint97_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="470" cy="406" r="1" transform="rotate(-180 470 406)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="350" cy="376" r="7.5" transform="rotate(-180 350 376)" fill="url(#paint98_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="350" cy="376" r="1" transform="rotate(-180 350 376)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="470" cy="376" r="7.5" transform="rotate(-180 470 376)" fill="url(#paint99_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="470" cy="376" r="1" transform="rotate(-180 470 376)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="350" cy="346" r="7.5" transform="rotate(-180 350 346)" fill="url(#paint100_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="350" cy="346" r="1" transform="rotate(-180 350 346)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="470" cy="346" r="7.5" transform="rotate(-180 470 346)" fill="url(#paint101_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="470" cy="346" r="1" transform="rotate(-180 470 346)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="470" cy="316" r="7.5" transform="rotate(-180 470 316)" fill="url(#paint102_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="470" cy="316" r="1" transform="rotate(-180 470 316)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="440" cy="316" r="7.5" transform="rotate(-180 440 316)" fill="url(#paint103_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="440" cy="316" r="1" transform="rotate(-180 440 316)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="410" cy="316" r="7.5" transform="rotate(-180 410 316)" fill="url(#paint104_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="410" cy="316" r="1" transform="rotate(-180 410 316)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="380" cy="316" r="7.5" transform="rotate(-180 380 316)" fill="url(#paint105_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="380" cy="316" r="1" transform="rotate(-180 380 316)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="350" cy="316" r="7.5" transform="rotate(-180 350 316)" fill="url(#paint106_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="350" cy="316" r="1" transform="rotate(-180 350 316)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="290" cy="376" r="7.5" transform="rotate(-180 290 376)" fill="url(#paint107_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="290" cy="376" r="1" transform="rotate(-180 290 376)" fill="black" fill-opacity="0.4"/>
|
||||
<rect x="230" y="437" width="120" height="2" transform="rotate(-180 230 437)" fill="#FF9D52"/>
|
||||
<rect x="230" y="377" width="120" height="2" transform="rotate(-180 230 377)" fill="#FF9D52"/>
|
||||
<rect x="230" y="317" width="120" height="2" transform="rotate(-180 230 317)" fill="#FF9D52"/>
|
||||
<rect x="171" y="438" width="2" height="120" transform="rotate(-180 171 438)" fill="#4196D8"/>
|
||||
<rect x="109.195" y="378" width="2" height="80" transform="rotate(-135 109.195 378)" fill="#D34F4F"/>
|
||||
<rect x="165.766" y="430.586" width="2" height="80" transform="rotate(135 165.766 430.586)" fill="#D34F4F"/>
|
||||
<rect x="230.766" y="375" width="2" height="80" transform="rotate(135 230.766 375)" fill="#D34F4F"/>
|
||||
<rect x="173.914" y="433.414" width="2" height="80" transform="rotate(-135 173.914 433.414)" fill="#D34F4F"/>
|
||||
<rect x="231" y="438" width="2" height="120" transform="rotate(-180 231 438)" fill="#FF9D52"/>
|
||||
<rect x="111" y="438" width="2" height="120" transform="rotate(-180 111 438)" fill="#FF9D52"/>
|
||||
<circle cx="230" cy="436" r="7.5" transform="rotate(-180 230 436)" fill="url(#paint108_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="230" cy="436" r="1" transform="rotate(-180 230 436)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="200" cy="436" r="7.5" transform="rotate(-180 200 436)" fill="url(#paint109_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="200" cy="436" r="1" transform="rotate(-180 200 436)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="170" cy="436" r="7.5" transform="rotate(-180 170 436)" fill="url(#paint110_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="170" cy="436" r="1" transform="rotate(-180 170 436)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="140" cy="436" r="7.5" transform="rotate(-180 140 436)" fill="url(#paint111_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="140" cy="436" r="1" transform="rotate(-180 140 436)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="110" cy="436" r="7.5" transform="rotate(-180 110 436)" fill="url(#paint112_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="110" cy="436" r="1" transform="rotate(-180 110 436)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="200" cy="376" r="7.5" transform="rotate(-180 200 376)" fill="url(#paint113_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="200" cy="376" r="1" transform="rotate(-180 200 376)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="140" cy="376" r="7.5" transform="rotate(-180 140 376)" fill="url(#paint114_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="140" cy="376" r="1" transform="rotate(-180 140 376)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="170" cy="376" r="7.5" transform="rotate(-180 170 376)" fill="url(#paint115_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="170" cy="376" r="1" transform="rotate(-180 170 376)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="110" cy="406" r="7.5" transform="rotate(-180 110 406)" fill="url(#paint116_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="110" cy="406" r="1" transform="rotate(-180 110 406)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="230" cy="406" r="7.5" transform="rotate(-180 230 406)" fill="url(#paint117_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="230" cy="406" r="1" transform="rotate(-180 230 406)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="110" cy="376" r="7.5" transform="rotate(-180 110 376)" fill="url(#paint118_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="110" cy="376" r="1" transform="rotate(-180 110 376)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="230" cy="376" r="7.5" transform="rotate(-180 230 376)" fill="url(#paint119_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="230" cy="376" r="1" transform="rotate(-180 230 376)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="110" cy="346" r="7.5" transform="rotate(-180 110 346)" fill="url(#paint120_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="110" cy="346" r="1" transform="rotate(-180 110 346)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="230" cy="346" r="7.5" transform="rotate(-180 230 346)" fill="url(#paint121_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="230" cy="346" r="1" transform="rotate(-180 230 346)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="230" cy="316" r="7.5" transform="rotate(-180 230 316)" fill="url(#paint122_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="230" cy="316" r="1" transform="rotate(-180 230 316)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="200" cy="316" r="7.5" transform="rotate(-180 200 316)" fill="url(#paint123_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="200" cy="316" r="1" transform="rotate(-180 200 316)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="170" cy="316" r="7.5" transform="rotate(-180 170 316)" fill="url(#paint124_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="170" cy="316" r="1" transform="rotate(-180 170 316)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="140" cy="316" r="7.5" transform="rotate(-180 140 316)" fill="url(#paint125_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="140" cy="316" r="1" transform="rotate(-180 140 316)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="110" cy="316" r="7.5" transform="rotate(-180 110 316)" fill="url(#paint126_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="110" cy="316" r="1" transform="rotate(-180 110 316)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="680" cy="466" r="7.5" transform="rotate(-180 680 466)" fill="url(#paint127_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="680" cy="466" r="1" transform="rotate(-180 680 466)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="440" cy="466" r="7.5" transform="rotate(-180 440 466)" fill="url(#paint128_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="440" cy="466" r="1" transform="rotate(-180 440 466)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="200" cy="466" r="7.5" transform="rotate(-180 200 466)" fill="url(#paint129_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="200" cy="466" r="1" transform="rotate(-180 200 466)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="620" cy="466" r="7.5" transform="rotate(-180 620 466)" fill="url(#paint130_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="620" cy="466" r="1" transform="rotate(-180 620 466)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="380" cy="466" r="7.5" transform="rotate(-180 380 466)" fill="url(#paint131_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="380" cy="466" r="1" transform="rotate(-180 380 466)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="140" cy="466" r="7.5" transform="rotate(-180 140 466)" fill="url(#paint132_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="140" cy="466" r="1" transform="rotate(-180 140 466)" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="590" cy="196" r="7.5" fill="url(#paint133_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="590" cy="196" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="620" cy="196" r="7.5" fill="url(#paint134_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="620" cy="196" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="650" cy="196" r="7.5" fill="url(#paint135_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="650" cy="196" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="680" cy="196" r="7.5" fill="url(#paint136_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="680" cy="196" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<circle cx="710" cy="196" r="7.5" fill="url(#paint137_radial_781_319508)" stroke="white"/>
|
||||
<circle cx="710" cy="196" r="1" fill="black" fill-opacity="0.4"/>
|
||||
<defs>
|
||||
<radialGradient id="paint0_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(110 76) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint1_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(140 76) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint2_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(170 76) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint3_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(200 76) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint4_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(230 76) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint5_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(140 136) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint6_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(200 136) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint7_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(170 136) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint8_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(230 106) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint9_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(110 106) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint10_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(230 136) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint11_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(110 136) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint12_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(230 166) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint13_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(110 166) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint14_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(290 136) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint15_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(290 256) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint16_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(530 256) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint17_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(350 76) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint18_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(380 76) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint19_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(410 76) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint20_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(440 76) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint21_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(470 76) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint22_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(380 136) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint23_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(440 136) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint24_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(410 136) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint25_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(470 106) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint26_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(350 106) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint27_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(470 136) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint28_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(350 136) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint29_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(470 166) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint30_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(350 166) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint31_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(350 196) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint32_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(380 196) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint33_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(410 196) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint34_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(440 196) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint35_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(470 196) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint36_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(530 136) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint37_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(590 76) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint38_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(620 76) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint39_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(650 76) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint40_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(680 76) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint41_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(710 76) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint42_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(620 136) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint43_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(680 136) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint44_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(650 136) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint45_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(710 106) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint46_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(590 106) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint47_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(710 136) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint48_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(590 136) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint49_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(710 166) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint50_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(590 166) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint51_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(60 136) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint52_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(60 376) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint53_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(60 256) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint54_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(110 196) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint55_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(140 196) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint56_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(170 196) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint57_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(200 196) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint58_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(230 196) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint59_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(760.461 375.379) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint60_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(760.461 255.379) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint61_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(760.461 135.383) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint62_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(140 46) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint63_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(380 46) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint64_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(620 46) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint65_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(200 46) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint66_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(440 46) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint67_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(680 46) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint68_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(710 436) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint69_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(680 436) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint70_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(650 436) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint71_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(620 436) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint72_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(590 436) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint73_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(680 376) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint74_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(620 376) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint75_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(650 376) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint76_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(590 406) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint77_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(710 406) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint78_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(590 376) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint79_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(710 376) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint80_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(590 346) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint81_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(710 346) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint82_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(710 316) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint83_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(680 316) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint84_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(650 316) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint85_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(620 316) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint86_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(590 316) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint87_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(530 376) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint88_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(470 436) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint89_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(440 436) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint90_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(410 436) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint91_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(380 436) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint92_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(350 436) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint93_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(440 376) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint94_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(380 376) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint95_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(410 376) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint96_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(350 406) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint97_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(470 406) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint98_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(350 376) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint99_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(470 376) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint100_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(350 346) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint101_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(470 346) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint102_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(470 316) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint103_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(440 316) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint104_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(410 316) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint105_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(380 316) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint106_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(350 316) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint107_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(290 376) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint108_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(230 436) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint109_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(200 436) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint110_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(170 436) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint111_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(140 436) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint112_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(110 436) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint113_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(200 376) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint114_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(140 376) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint115_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(170 376) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint116_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(110 406) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint117_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(230 406) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint118_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(110 376) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint119_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(230 376) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint120_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(110 346) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint121_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(230 346) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint122_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(230 316) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint123_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(200 316) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint124_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(170 316) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint125_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(140 316) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint126_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(110 316) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint127_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(680 466) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint128_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(440 466) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint129_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(200 466) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint130_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(620 466) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint131_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(380 466) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint132_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(140 466) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint133_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(590 196) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint134_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(620 196) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint135_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(650 196) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint136_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(680 196) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
<radialGradient id="paint137_radial_781_319508" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(710 196) rotate(90) scale(7.5)">
|
||||
<stop stop-color="#B39351"/>
|
||||
<stop offset="1" stop-color="#FFECC4"/>
|
||||
</radialGradient>
|
||||
</defs>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 71 KiB |
4
examples/.keep
Normal file
4
examples/.keep
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
File generated from our OpenAPI spec by Stainless.
|
||||
|
||||
This directory can be used to store example files demonstrating usage of this SDK.
|
||||
It is ignored by Stainless code generation and its content (other than this keep file) won't be touched.
|
||||
50
mypy.ini
Normal file
50
mypy.ini
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
[mypy]
|
||||
pretty = True
|
||||
show_error_codes = True
|
||||
|
||||
# Exclude _files.py because mypy isn't smart enough to apply
|
||||
# the correct type narrowing and as this is an internal module
|
||||
# it's fine to just use Pyright.
|
||||
#
|
||||
# We also exclude our `tests` as mypy doesn't always infer
|
||||
# types correctly and Pyright will still catch any type errors.
|
||||
exclude = ^(src/tinker/_files\.py|_dev/.*\.py|tests/.*)$
|
||||
|
||||
strict_equality = True
|
||||
implicit_reexport = True
|
||||
check_untyped_defs = True
|
||||
no_implicit_optional = True
|
||||
|
||||
warn_return_any = True
|
||||
warn_unreachable = True
|
||||
warn_unused_configs = True
|
||||
|
||||
# Turn these options off as it could cause conflicts
|
||||
# with the Pyright options.
|
||||
warn_unused_ignores = False
|
||||
warn_redundant_casts = False
|
||||
|
||||
disallow_any_generics = True
|
||||
disallow_untyped_defs = True
|
||||
disallow_untyped_calls = True
|
||||
disallow_subclassing_any = True
|
||||
disallow_incomplete_defs = True
|
||||
disallow_untyped_decorators = True
|
||||
cache_fine_grained = True
|
||||
|
||||
# By default, mypy reports an error if you assign a value to the result
|
||||
# of a function call that doesn't return anything. We do this in our test
|
||||
# cases:
|
||||
# ```
|
||||
# result = ...
|
||||
# assert result is None
|
||||
# ```
|
||||
# Changing this codegen to make mypy happy would increase complexity
|
||||
# and would not be worth it.
|
||||
disable_error_code = func-returns-value,overload-cannot-match
|
||||
|
||||
# https://github.com/python/mypy/issues/12162
|
||||
[mypy.overrides]
|
||||
module = "black.files.*"
|
||||
ignore_errors = true
|
||||
ignore_missing_imports = true
|
||||
185
pyproject.toml
Normal file
185
pyproject.toml
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
[project]
|
||||
name = "tinker"
|
||||
version = "0.1.3"
|
||||
description = "The official Python SDK for the tinker API"
|
||||
readme = "README.md"
|
||||
license = "Apache-2.0"
|
||||
authors = [
|
||||
{ name = "Tinker authors", email = "tinker@thinkingmachines.ai" },
|
||||
]
|
||||
keywords = ["tinker", "machine learning"]
|
||||
dependencies = [
|
||||
"httpx[http2]>=0.23.0, <1",
|
||||
"pydantic>=1.9.0, <3",
|
||||
"typing-extensions>=4.10, <5",
|
||||
"anyio>=3.5.0, <5",
|
||||
"distro>=1.7.0, <2",
|
||||
"sniffio",
|
||||
"numpy",
|
||||
"torch",
|
||||
]
|
||||
requires-python = ">= 3.11"
|
||||
classifiers = [
|
||||
"Typing :: Typed",
|
||||
"Intended Audience :: Developers",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Operating System :: OS Independent",
|
||||
"Operating System :: POSIX",
|
||||
"Operating System :: MacOS",
|
||||
"Operating System :: POSIX :: Linux",
|
||||
"Operating System :: Microsoft :: Windows",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
"License :: OSI Approved :: Apache Software License"
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://thinkingmachines.ai/tinker"
|
||||
Repository = "https://github.com/thinking-machines-lab/tinker"
|
||||
Documentation = "https://tinker-docs.thinkingmachines.ai/"
|
||||
|
||||
[project.optional-dependencies]
|
||||
aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.8"]
|
||||
|
||||
[tool.uv]
|
||||
managed = true
|
||||
required-version = ">=0.5.0"
|
||||
# version pins are in uv.lock
|
||||
dev-dependencies = [
|
||||
"pyright==1.1.402",
|
||||
"mypy",
|
||||
"respx",
|
||||
"pytest",
|
||||
"pytest-asyncio",
|
||||
"ruff",
|
||||
"time-machine",
|
||||
"dirty-equals>=0.6.0",
|
||||
"importlib-metadata>=6.7.0",
|
||||
"rich>=13.7.1",
|
||||
"nest_asyncio==1.6.0",
|
||||
"pytest-xdist>=3.6.1",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling==1.26.3", "hatch-fancy-pypi-readme"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build]
|
||||
include = [
|
||||
"src/*"
|
||||
]
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/tinker"]
|
||||
|
||||
[tool.hatch.build.targets.sdist]
|
||||
# Basically everything except hidden files/directories (such as .github, .devcontainers, .python-version, etc)
|
||||
include = [
|
||||
"/*.toml",
|
||||
"/*.json",
|
||||
"/*.lock",
|
||||
"/*.md",
|
||||
"/mypy.ini",
|
||||
"/noxfile.py",
|
||||
"bin/*",
|
||||
"examples/*",
|
||||
"src/*",
|
||||
"tests/*",
|
||||
]
|
||||
|
||||
[tool.hatch.metadata.hooks.fancy-pypi-readme]
|
||||
content-type = "text/markdown"
|
||||
|
||||
[[tool.hatch.metadata.hooks.fancy-pypi-readme.fragments]]
|
||||
path = "README.md"
|
||||
|
||||
[[tool.hatch.metadata.hooks.fancy-pypi-readme.substitutions]]
|
||||
# replace relative links with absolute links
|
||||
pattern = '\[(.+?)\]\(((?!https?://)\S+?)\)'
|
||||
replacement = '[\1](https://github.com/stainless-sdks/tinker-python/tree/main/\g<2>)'
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
addopts = "--tb=short -n auto"
|
||||
xfail_strict = true
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "session"
|
||||
filterwarnings = [
|
||||
"error"
|
||||
]
|
||||
|
||||
[tool.pyright]
|
||||
# this enables practically every flag given by pyright.
|
||||
# there are a couple of flags that are still disabled by
|
||||
# default in strict mode as they are experimental and niche.
|
||||
typeCheckingMode = "strict"
|
||||
pythonVersion = "3.8"
|
||||
|
||||
exclude = [
|
||||
"_dev",
|
||||
".venv",
|
||||
".nox",
|
||||
]
|
||||
|
||||
reportImplicitOverride = true
|
||||
reportOverlappingOverload = false
|
||||
|
||||
reportImportCycles = false
|
||||
reportPrivateUsage = false
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
output-format = "grouped"
|
||||
target-version = "py38"
|
||||
|
||||
[tool.ruff.format]
|
||||
docstring-code-format = true
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
# isort
|
||||
"I",
|
||||
# bugbear rules
|
||||
"B",
|
||||
# remove unused imports
|
||||
"F401",
|
||||
# bare except statements
|
||||
"E722",
|
||||
# unused arguments
|
||||
"ARG",
|
||||
# print statements
|
||||
"T201",
|
||||
"T203",
|
||||
# misuse of typing.TYPE_CHECKING
|
||||
"TC004",
|
||||
# import rules
|
||||
"TID251",
|
||||
]
|
||||
ignore = [
|
||||
# mutable defaults
|
||||
"B006",
|
||||
]
|
||||
unfixable = [
|
||||
# disable auto fix for print statements
|
||||
"T201",
|
||||
"T203",
|
||||
]
|
||||
|
||||
[tool.ruff.lint.flake8-tidy-imports.banned-api]
|
||||
"functools.lru_cache".msg = "This function does not retain type information for the wrapped function's arguments; The `lru_cache` function from `_utils` should be used instead"
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
length-sort = true
|
||||
length-sort-straight = true
|
||||
combine-as-imports = true
|
||||
extra-standard-library = ["typing_extensions"]
|
||||
known-first-party = ["tinker", "tests"]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"bin/**.py" = ["T201", "T203"]
|
||||
"scripts/**.py" = ["T201", "T203"]
|
||||
"tests/**.py" = ["T201", "T203"]
|
||||
"examples/**.py" = ["T201", "T203"]
|
||||
105
requirements-dev.lock
Normal file
105
requirements-dev.lock
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
# This file was autogenerated by uv via the following command:
|
||||
# uv export -o requirements-dev.lock --no-hashes
|
||||
-e .
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
anyio==4.5.2 ; python_full_version < '3.9'
|
||||
# via
|
||||
# httpx
|
||||
# tinker
|
||||
anyio==4.8.0 ; python_full_version >= '3.9'
|
||||
# via
|
||||
# httpx
|
||||
# tinker
|
||||
certifi==2024.12.14
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
colorama==0.4.6 ; sys_platform == 'win32'
|
||||
# via pytest
|
||||
dirty-equals==0.9.0
|
||||
distro==1.9.0
|
||||
# via tinker
|
||||
exceptiongroup==1.2.2 ; python_full_version < '3.11'
|
||||
# via
|
||||
# anyio
|
||||
# pytest
|
||||
execnet==2.1.1
|
||||
# via pytest-xdist
|
||||
h11==0.16.0
|
||||
# via httpcore
|
||||
httpcore==1.0.9
|
||||
# via httpx
|
||||
httpx==0.28.1
|
||||
# via
|
||||
# respx
|
||||
# tinker
|
||||
idna==3.10
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
importlib-metadata==8.5.0 ; python_full_version < '3.9'
|
||||
importlib-metadata==8.6.1 ; python_full_version >= '3.9'
|
||||
iniconfig==2.0.0
|
||||
# via pytest
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mypy==1.14.1
|
||||
mypy-extensions==1.0.0
|
||||
# via mypy
|
||||
nest-asyncio==1.6.0
|
||||
nodeenv==1.9.1
|
||||
# via pyright
|
||||
packaging==24.2
|
||||
# via pytest
|
||||
pluggy==1.5.0
|
||||
# via pytest
|
||||
pydantic==2.10.3
|
||||
# via tinker
|
||||
pydantic-core==2.27.1
|
||||
# via pydantic
|
||||
pygments==2.19.1
|
||||
# via rich
|
||||
pyright==1.1.399
|
||||
pytest==8.3.3
|
||||
# via
|
||||
# pytest-asyncio
|
||||
# pytest-xdist
|
||||
pytest-asyncio==0.24.0
|
||||
pytest-xdist==3.6.1 ; python_full_version < '3.9'
|
||||
pytest-xdist==3.7.0 ; python_full_version >= '3.9'
|
||||
python-dateutil==2.9.0.post0
|
||||
# via time-machine
|
||||
pytz==2024.2 ; python_full_version < '3.9'
|
||||
# via dirty-equals
|
||||
respx==0.22.0
|
||||
rich==13.9.4
|
||||
ruff==0.9.4
|
||||
six==1.17.0
|
||||
# via python-dateutil
|
||||
sniffio==1.3.1
|
||||
# via
|
||||
# anyio
|
||||
# tinker
|
||||
time-machine==2.15.0 ; python_full_version < '3.9'
|
||||
time-machine==2.16.0 ; python_full_version >= '3.9'
|
||||
tomli==2.2.1 ; python_full_version < '3.11'
|
||||
# via
|
||||
# mypy
|
||||
# pytest
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# annotated-types
|
||||
# anyio
|
||||
# mypy
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# pyright
|
||||
# rich
|
||||
# tinker
|
||||
zipp==3.20.2 ; python_full_version < '3.9'
|
||||
# via importlib-metadata
|
||||
zipp==3.21.0 ; python_full_version >= '3.9'
|
||||
# via importlib-metadata
|
||||
22
scripts/bootstrap
Executable file
22
scripts/bootstrap
Executable file
|
|
@ -0,0 +1,22 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
|
||||
cd "$(dirname "$0")/.."
|
||||
|
||||
if [ -f "Brewfile" ] && [ "$(uname -s)" = "Darwin" ] && [ "$SKIP_BREW" != "1" ]; then
|
||||
brew bundle check >/dev/null 2>&1 || {
|
||||
echo "==> Installing Homebrew dependencies…"
|
||||
brew bundle
|
||||
}
|
||||
fi
|
||||
|
||||
echo "==> Installing Python…"
|
||||
uv python install
|
||||
|
||||
echo "==> Installing Python dependencies…"
|
||||
uv sync --all-extras
|
||||
|
||||
echo "==> Exporting Python dependencies…"
|
||||
# note: `--no-hashes` is required because of https://github.com/pypa/pip/issues/4995
|
||||
uv export -o requirements-dev.lock --no-hashes
|
||||
14
scripts/format
Executable file
14
scripts/format
Executable file
|
|
@ -0,0 +1,14 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
|
||||
cd "$(dirname "$0")/.."
|
||||
|
||||
echo "==> Running ruff"
|
||||
uv run ruff format
|
||||
uv run ruff check --fix .
|
||||
# run formatting again to fix any inconsistencies when imports are stripped
|
||||
uv run ruff format
|
||||
|
||||
echo "==> Formatting docs"
|
||||
uv run python scripts/utils/ruffen-docs.py README.md api.md
|
||||
17
scripts/lint
Executable file
17
scripts/lint
Executable file
|
|
@ -0,0 +1,17 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
|
||||
cd "$(dirname "$0")/.."
|
||||
|
||||
echo "==> Running ruff"
|
||||
uv run ruff check .
|
||||
|
||||
echo "==> Running pyright"
|
||||
uv run pyright
|
||||
|
||||
echo "==> Running mypy"
|
||||
uv run mypy .
|
||||
|
||||
echo "==> Making sure it imports"
|
||||
uv run python -c 'import tinker'
|
||||
41
scripts/mock
Executable file
41
scripts/mock
Executable file
|
|
@ -0,0 +1,41 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
|
||||
cd "$(dirname "$0")/.."
|
||||
|
||||
if [[ -n "$1" && "$1" != '--'* ]]; then
|
||||
URL="$1"
|
||||
shift
|
||||
else
|
||||
URL="$(grep 'openapi_spec_url' .stats.yml | cut -d' ' -f2)"
|
||||
fi
|
||||
|
||||
# Check if the URL is empty
|
||||
if [ -z "$URL" ]; then
|
||||
echo "Error: No OpenAPI spec path/url provided or found in .stats.yml"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "==> Starting mock server with URL ${URL}"
|
||||
|
||||
# Run prism mock on the given spec
|
||||
if [ "$1" == "--daemon" ]; then
|
||||
npm exec --package=@stainless-api/prism-cli@5.15.0 -- prism mock "$URL" &>.prism.log &
|
||||
|
||||
# Wait for server to come online
|
||||
echo -n "Waiting for server"
|
||||
while ! grep -q "✖ fatal\|Prism is listening" ".prism.log"; do
|
||||
echo -n "."
|
||||
sleep 0.1
|
||||
done
|
||||
|
||||
if grep -q "✖ fatal" ".prism.log"; then
|
||||
cat .prism.log
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo
|
||||
else
|
||||
npm exec --package=@stainless-api/prism-cli@5.15.0 -- prism mock "$URL"
|
||||
fi
|
||||
77
scripts/test
Executable file
77
scripts/test
Executable file
|
|
@ -0,0 +1,77 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
|
||||
cd "$(dirname "$0")/.."
|
||||
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[0;33m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
function prism_is_running() {
|
||||
curl --silent "http://localhost:4010" >/dev/null 2>&1
|
||||
}
|
||||
|
||||
kill_server_on_port() {
|
||||
pids=$(lsof -t -i tcp:"$1" || echo "")
|
||||
if [ "$pids" != "" ]; then
|
||||
kill "$pids"
|
||||
echo "Stopped $pids."
|
||||
fi
|
||||
}
|
||||
|
||||
function is_overriding_api_base_url() {
|
||||
[ -n "$TEST_API_BASE_URL" ]
|
||||
}
|
||||
|
||||
if ! is_overriding_api_base_url && ! prism_is_running; then
|
||||
# When we exit this script, make sure to kill the background mock server process
|
||||
trap 'kill_server_on_port 4010' EXIT
|
||||
|
||||
# Start the dev server
|
||||
./scripts/mock --daemon
|
||||
fi
|
||||
|
||||
if is_overriding_api_base_url; then
|
||||
echo -e "${GREEN}✔ Running tests against ${TEST_API_BASE_URL}${NC}"
|
||||
echo
|
||||
elif ! prism_is_running; then
|
||||
echo -e "${RED}ERROR:${NC} The test suite will not run without a mock Prism server"
|
||||
echo -e "running against your OpenAPI spec."
|
||||
echo
|
||||
echo -e "To run the server, pass in the path or url of your OpenAPI"
|
||||
echo -e "spec to the prism command:"
|
||||
echo
|
||||
echo -e " \$ ${YELLOW}npm exec --package=@stainless-api/prism-cli@5.15.0 -- prism mock path/to/your.openapi.yml${NC}"
|
||||
echo
|
||||
|
||||
exit 1
|
||||
else
|
||||
echo -e "${GREEN}✔ Mock prism server is running with your OpenAPI spec${NC}"
|
||||
echo
|
||||
fi
|
||||
|
||||
export DEFER_PYDANTIC_BUILD=false
|
||||
|
||||
function run_tests() {
|
||||
echo "==> Running tests with Pydantic v2"
|
||||
uv run --all-extras --all-groups pytest "$@"
|
||||
|
||||
echo "==> Running tests with Pydantic v1"
|
||||
uv pip install 'pydantic<2'
|
||||
uv run --all-extras --all-groups pytest "$@"
|
||||
}
|
||||
|
||||
# If UV_PYTHON is already set in the environment, just run the command once
|
||||
if [[ -n "$UV_PYTHON" ]]; then
|
||||
run_tests "$@"
|
||||
else
|
||||
# If UV_PYTHON is not set, run the command for min and max versions
|
||||
|
||||
echo "==> Running tests for Python 3.9"
|
||||
UV_PYTHON=3.9 run_tests "$@"
|
||||
|
||||
echo "==> Running tests for Python 3.13"
|
||||
UV_PYTHON=3.13 run_tests "$@"
|
||||
fi
|
||||
167
scripts/utils/ruffen-docs.py
Normal file
167
scripts/utils/ruffen-docs.py
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
# fork of https://github.com/asottile/blacken-docs adapted for ruff
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import sys
|
||||
import argparse
|
||||
import textwrap
|
||||
import contextlib
|
||||
import subprocess
|
||||
from typing import Match, Optional, Sequence, Generator, NamedTuple, cast
|
||||
|
||||
MD_RE = re.compile(
|
||||
r"(?P<before>^(?P<indent> *)```\s*python\n)" r"(?P<code>.*?)" r"(?P<after>^(?P=indent)```\s*$)",
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
MD_PYCON_RE = re.compile(
|
||||
r"(?P<before>^(?P<indent> *)```\s*pycon\n)" r"(?P<code>.*?)" r"(?P<after>^(?P=indent)```.*$)",
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
PYCON_PREFIX = ">>> "
|
||||
PYCON_CONTINUATION_PREFIX = "..."
|
||||
PYCON_CONTINUATION_RE = re.compile(
|
||||
rf"^{re.escape(PYCON_CONTINUATION_PREFIX)}( |$)",
|
||||
)
|
||||
DEFAULT_LINE_LENGTH = 100
|
||||
|
||||
|
||||
class CodeBlockError(NamedTuple):
|
||||
offset: int
|
||||
exc: Exception
|
||||
|
||||
|
||||
def format_str(
|
||||
src: str,
|
||||
) -> tuple[str, Sequence[CodeBlockError]]:
|
||||
errors: list[CodeBlockError] = []
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _collect_error(match: Match[str]) -> Generator[None, None, None]:
|
||||
try:
|
||||
yield
|
||||
except Exception as e:
|
||||
errors.append(CodeBlockError(match.start(), e))
|
||||
|
||||
def _md_match(match: Match[str]) -> str:
|
||||
code = textwrap.dedent(match["code"])
|
||||
with _collect_error(match):
|
||||
code = format_code_block(code)
|
||||
code = textwrap.indent(code, match["indent"])
|
||||
return f"{match['before']}{code}{match['after']}"
|
||||
|
||||
def _pycon_match(match: Match[str]) -> str:
|
||||
code = ""
|
||||
fragment = cast(Optional[str], None)
|
||||
|
||||
def finish_fragment() -> None:
|
||||
nonlocal code
|
||||
nonlocal fragment
|
||||
|
||||
if fragment is not None:
|
||||
with _collect_error(match):
|
||||
fragment = format_code_block(fragment)
|
||||
fragment_lines = fragment.splitlines()
|
||||
code += f"{PYCON_PREFIX}{fragment_lines[0]}\n"
|
||||
for line in fragment_lines[1:]:
|
||||
# Skip blank lines to handle Black adding a blank above
|
||||
# functions within blocks. A blank line would end the REPL
|
||||
# continuation prompt.
|
||||
#
|
||||
# >>> if True:
|
||||
# ... def f():
|
||||
# ... pass
|
||||
# ...
|
||||
if line:
|
||||
code += f"{PYCON_CONTINUATION_PREFIX} {line}\n"
|
||||
if fragment_lines[-1].startswith(" "):
|
||||
code += f"{PYCON_CONTINUATION_PREFIX}\n"
|
||||
fragment = None
|
||||
|
||||
indentation = None
|
||||
for line in match["code"].splitlines():
|
||||
orig_line, line = line, line.lstrip()
|
||||
if indentation is None and line:
|
||||
indentation = len(orig_line) - len(line)
|
||||
continuation_match = PYCON_CONTINUATION_RE.match(line)
|
||||
if continuation_match and fragment is not None:
|
||||
fragment += line[continuation_match.end() :] + "\n"
|
||||
else:
|
||||
finish_fragment()
|
||||
if line.startswith(PYCON_PREFIX):
|
||||
fragment = line[len(PYCON_PREFIX) :] + "\n"
|
||||
else:
|
||||
code += orig_line[indentation:] + "\n"
|
||||
finish_fragment()
|
||||
return code
|
||||
|
||||
def _md_pycon_match(match: Match[str]) -> str:
|
||||
code = _pycon_match(match)
|
||||
code = textwrap.indent(code, match["indent"])
|
||||
return f"{match['before']}{code}{match['after']}"
|
||||
|
||||
src = MD_RE.sub(_md_match, src)
|
||||
src = MD_PYCON_RE.sub(_md_pycon_match, src)
|
||||
return src, errors
|
||||
|
||||
|
||||
def format_code_block(code: str) -> str:
|
||||
return subprocess.check_output(
|
||||
[
|
||||
sys.executable,
|
||||
"-m",
|
||||
"ruff",
|
||||
"format",
|
||||
"--stdin-filename=script.py",
|
||||
f"--line-length={DEFAULT_LINE_LENGTH}",
|
||||
],
|
||||
encoding="utf-8",
|
||||
input=code,
|
||||
)
|
||||
|
||||
|
||||
def format_file(
|
||||
filename: str,
|
||||
skip_errors: bool,
|
||||
) -> int:
|
||||
with open(filename, encoding="UTF-8") as f:
|
||||
contents = f.read()
|
||||
new_contents, errors = format_str(contents)
|
||||
for error in errors:
|
||||
lineno = contents[: error.offset].count("\n") + 1
|
||||
print(f"{filename}:{lineno}: code block parse error {error.exc}")
|
||||
if errors and not skip_errors:
|
||||
return 1
|
||||
if contents != new_contents:
|
||||
print(f"{filename}: Rewriting...")
|
||||
with open(filename, "w", encoding="UTF-8") as f:
|
||||
f.write(new_contents)
|
||||
return 0
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def main(argv: Sequence[str] | None = None) -> int:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--line-length",
|
||||
type=int,
|
||||
default=DEFAULT_LINE_LENGTH,
|
||||
)
|
||||
parser.add_argument(
|
||||
"-S",
|
||||
"--skip-string-normalization",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument("-E", "--skip-errors", action="store_true")
|
||||
parser.add_argument("filenames", nargs="*")
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
retv = 0
|
||||
for filename in args.filenames:
|
||||
retv |= format_file(filename, skip_errors=args.skip_errors)
|
||||
return retv
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
27
scripts/utils/upload-artifact.sh
Executable file
27
scripts/utils/upload-artifact.sh
Executable file
|
|
@ -0,0 +1,27 @@
|
|||
#!/usr/bin/env bash
|
||||
set -exuo pipefail
|
||||
|
||||
FILENAME=$(basename dist/*.whl)
|
||||
|
||||
RESPONSE=$(curl -X POST "$URL?filename=$FILENAME" \
|
||||
-H "Authorization: Bearer $AUTH" \
|
||||
-H "Content-Type: application/json")
|
||||
|
||||
SIGNED_URL=$(echo "$RESPONSE" | jq -r '.url')
|
||||
|
||||
if [[ "$SIGNED_URL" == "null" ]]; then
|
||||
echo -e "\033[31mFailed to get signed URL.\033[0m"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
UPLOAD_RESPONSE=$(curl -v -X PUT \
|
||||
-H "Content-Type: binary/octet-stream" \
|
||||
--data-binary "@dist/$FILENAME" "$SIGNED_URL" 2>&1)
|
||||
|
||||
if echo "$UPLOAD_RESPONSE" | grep -q "HTTP/[0-9.]* 200"; then
|
||||
echo -e "\033[32mUploaded build to Stainless storage.\033[0m"
|
||||
echo -e "\033[32mInstallation: pip install 'https://pkg.stainless.com/s/tinker-python/$SHA/$FILENAME'\033[0m"
|
||||
else
|
||||
echo -e "\033[31mFailed to upload artifact.\033[0m"
|
||||
exit 1
|
||||
fi
|
||||
132
src/tinker/__init__.py
Normal file
132
src/tinker/__init__.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
import typing as _t
|
||||
|
||||
from . import types
|
||||
from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes
|
||||
from ._utils import file_from_path
|
||||
from ._client import Timeout, Transport, RequestOptions
|
||||
from ._models import BaseModel
|
||||
from ._version import __title__, __version__
|
||||
from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse
|
||||
from ._constants import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES, DEFAULT_CONNECTION_LIMITS
|
||||
from ._exceptions import (
|
||||
APIError,
|
||||
TinkerError,
|
||||
ConflictError,
|
||||
NotFoundError,
|
||||
APIStatusError,
|
||||
RateLimitError,
|
||||
APITimeoutError,
|
||||
BadRequestError,
|
||||
APIConnectionError,
|
||||
AuthenticationError,
|
||||
InternalServerError,
|
||||
PermissionDeniedError,
|
||||
UnprocessableEntityError,
|
||||
APIResponseValidationError,
|
||||
)
|
||||
from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient
|
||||
from ._utils._logs import setup_logging as _setup_logging
|
||||
from .lib.public_interfaces import TrainingClient, ServiceClient, SamplingClient, APIFuture
|
||||
|
||||
# Import commonly used types for easier access
|
||||
from .types import (
|
||||
AdamParams,
|
||||
Checkpoint,
|
||||
CheckpointType,
|
||||
Datum,
|
||||
EncodedTextChunk,
|
||||
ForwardBackwardOutput,
|
||||
LoraConfig,
|
||||
ModelID,
|
||||
ModelInput,
|
||||
ModelInputChunk,
|
||||
OptimStepRequest,
|
||||
OptimStepResponse,
|
||||
ParsedCheckpointTinkerPath,
|
||||
SampledSequence,
|
||||
SampleRequest,
|
||||
SampleResponse,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
TensorData,
|
||||
TensorDtype,
|
||||
TrainingRun,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Core clients
|
||||
"TrainingClient",
|
||||
"ServiceClient",
|
||||
"SamplingClient",
|
||||
"APIFuture",
|
||||
|
||||
# Commonly used types
|
||||
"AdamParams",
|
||||
"Checkpoint",
|
||||
"CheckpointType",
|
||||
"Datum",
|
||||
"EncodedTextChunk",
|
||||
"ForwardBackwardOutput",
|
||||
"LoraConfig",
|
||||
"ModelID",
|
||||
"ModelInput",
|
||||
"ModelInputChunk",
|
||||
"OptimStepRequest",
|
||||
"OptimStepResponse",
|
||||
"ParsedCheckpointTinkerPath",
|
||||
"SampledSequence",
|
||||
"SampleRequest",
|
||||
"SampleResponse",
|
||||
"SamplingParams",
|
||||
"StopReason",
|
||||
"TensorData",
|
||||
"TensorDtype",
|
||||
"TrainingRun",
|
||||
|
||||
# Client configuration
|
||||
"Timeout",
|
||||
"RequestOptions",
|
||||
|
||||
# Exception types
|
||||
"TinkerError",
|
||||
"APIError",
|
||||
"APIStatusError",
|
||||
"APITimeoutError",
|
||||
"APIConnectionError",
|
||||
"APIResponseValidationError",
|
||||
"BadRequestError",
|
||||
"AuthenticationError",
|
||||
"PermissionDeniedError",
|
||||
"NotFoundError",
|
||||
"ConflictError",
|
||||
"UnprocessableEntityError",
|
||||
"RateLimitError",
|
||||
"InternalServerError",
|
||||
|
||||
# Keep types module for advanced use
|
||||
"types",
|
||||
|
||||
# Version info
|
||||
"__version__",
|
||||
"__title__",
|
||||
]
|
||||
|
||||
if not _t.TYPE_CHECKING:
|
||||
from ._utils._resources_proxy import resources as resources
|
||||
|
||||
_setup_logging()
|
||||
|
||||
# Update the __module__ attribute for exported symbols so that
|
||||
# error messages point to this module instead of the module
|
||||
# it was originally defined in, e.g.
|
||||
# tinker._exceptions.NotFoundError -> tinker.NotFoundError
|
||||
__locals = locals()
|
||||
for __name in __all__:
|
||||
if not __name.startswith("__"):
|
||||
try:
|
||||
__locals[__name].__module__ = "tinker"
|
||||
except (TypeError, AttributeError):
|
||||
# Some of our exported symbols are builtins which we can't set attributes for.
|
||||
pass
|
||||
1997
src/tinker/_base_client.py
Normal file
1997
src/tinker/_base_client.py
Normal file
File diff suppressed because it is too large
Load diff
668
src/tinker/_client.py
Normal file
668
src/tinker/_client.py
Normal file
|
|
@ -0,0 +1,668 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Union, Mapping
|
||||
from typing_extensions import Self, override
|
||||
|
||||
import httpx
|
||||
|
||||
from . import _exceptions
|
||||
from ._qs import Querystring
|
||||
from ._types import (
|
||||
NOT_GIVEN,
|
||||
Omit,
|
||||
Timeout,
|
||||
NotGiven,
|
||||
Transport,
|
||||
ProxiesTypes,
|
||||
RequestOptions,
|
||||
)
|
||||
from ._utils import is_given, get_async_library
|
||||
from ._compat import cached_property
|
||||
from ._version import __version__
|
||||
from ._streaming import Stream as Stream, AsyncStream as AsyncStream
|
||||
from ._exceptions import TinkerError, APIStatusError
|
||||
from ._base_client import (
|
||||
DEFAULT_MAX_RETRIES,
|
||||
SyncAPIClient,
|
||||
AsyncAPIClient,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .resources import models, futures, service, weights, sampling, training, telemetry
|
||||
from .resources.models import ModelsResource, AsyncModelsResource
|
||||
from .resources.futures import FuturesResource, AsyncFuturesResource
|
||||
from .resources.service import ServiceResource, AsyncServiceResource
|
||||
from .resources.weights import WeightsResource, AsyncWeightsResource
|
||||
from .resources.sampling import SamplingResource, AsyncSamplingResource
|
||||
from .resources.training import TrainingResource, AsyncTrainingResource
|
||||
from .resources.telemetry import TelemetryResource, AsyncTelemetryResource
|
||||
|
||||
__all__ = ["Timeout", "Transport", "ProxiesTypes", "RequestOptions", "Tinker", "AsyncTinker", "Client", "AsyncClient"]
|
||||
|
||||
|
||||
class Tinker(SyncAPIClient):
|
||||
# client options
|
||||
api_key: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
base_url: str | httpx.URL | None = None,
|
||||
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
|
||||
max_retries: int = DEFAULT_MAX_RETRIES,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
default_query: Mapping[str, object] | None = None,
|
||||
# Configure a custom httpx client.
|
||||
# We provide a `DefaultHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`.
|
||||
# See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
|
||||
http_client: httpx.Client | None = None,
|
||||
# Enable or disable schema validation for data returned by the API.
|
||||
# When enabled an error APIResponseValidationError is raised
|
||||
# if the API responds with invalid data for the expected schema.
|
||||
#
|
||||
# This parameter may be removed or changed in the future.
|
||||
# If you rely on this feature, please open a GitHub issue
|
||||
# outlining your use-case to help us decide if it should be
|
||||
# part of our public interface in the future.
|
||||
_strict_response_validation: bool = False,
|
||||
) -> None:
|
||||
"""Construct a new synchronous Tinker client instance.
|
||||
|
||||
This automatically infers the `api_key` argument from the `TINKER_API_KEY` environment variable if it is not provided.
|
||||
"""
|
||||
if api_key is None:
|
||||
api_key = os.environ.get("TINKER_API_KEY")
|
||||
if api_key is None:
|
||||
raise TinkerError(
|
||||
"The api_key client option must be set either by passing api_key to the client or by setting the TINKER_API_KEY environment variable"
|
||||
)
|
||||
self.api_key = api_key
|
||||
|
||||
if base_url is None:
|
||||
base_url = os.environ.get("TINKER_BASE_URL")
|
||||
if base_url is None:
|
||||
base_url = f"https://tinker.thinkingmachines.dev/services/tinker-prod"
|
||||
|
||||
super().__init__(
|
||||
version=__version__,
|
||||
base_url=base_url,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout,
|
||||
http_client=http_client,
|
||||
custom_headers=default_headers,
|
||||
custom_query=default_query,
|
||||
_strict_response_validation=_strict_response_validation,
|
||||
)
|
||||
|
||||
self._idempotency_header = "X-Idempotency-Key"
|
||||
|
||||
@cached_property
|
||||
def service(self) -> ServiceResource:
|
||||
from .resources.service import ServiceResource
|
||||
|
||||
return ServiceResource(self)
|
||||
|
||||
@cached_property
|
||||
def training(self) -> TrainingResource:
|
||||
from .resources.training import TrainingResource
|
||||
|
||||
return TrainingResource(self)
|
||||
|
||||
@cached_property
|
||||
def models(self) -> ModelsResource:
|
||||
from .resources.models import ModelsResource
|
||||
|
||||
return ModelsResource(self)
|
||||
|
||||
@cached_property
|
||||
def weights(self) -> WeightsResource:
|
||||
from .resources.weights import WeightsResource
|
||||
|
||||
return WeightsResource(self)
|
||||
|
||||
@cached_property
|
||||
def sampling(self) -> SamplingResource:
|
||||
from .resources.sampling import SamplingResource
|
||||
|
||||
return SamplingResource(self)
|
||||
|
||||
@cached_property
|
||||
def futures(self) -> FuturesResource:
|
||||
from .resources.futures import FuturesResource
|
||||
|
||||
return FuturesResource(self)
|
||||
|
||||
@cached_property
|
||||
def telemetry(self) -> TelemetryResource:
|
||||
from .resources.telemetry import TelemetryResource
|
||||
|
||||
return TelemetryResource(self)
|
||||
|
||||
@cached_property
|
||||
def with_raw_response(self) -> TinkerWithRawResponse:
|
||||
return TinkerWithRawResponse(self)
|
||||
|
||||
@cached_property
|
||||
def with_streaming_response(self) -> TinkerWithStreamedResponse:
|
||||
return TinkerWithStreamedResponse(self)
|
||||
|
||||
@property
|
||||
@override
|
||||
def qs(self) -> Querystring:
|
||||
return Querystring(array_format="comma")
|
||||
|
||||
@property
|
||||
@override
|
||||
def auth_headers(self) -> dict[str, str]:
|
||||
api_key = self.api_key
|
||||
return {"X-API-Key": api_key}
|
||||
|
||||
@property
|
||||
@override
|
||||
def default_headers(self) -> dict[str, str | Omit]:
|
||||
return {
|
||||
**super().default_headers,
|
||||
"X-Stainless-Async": "false",
|
||||
**self._custom_headers,
|
||||
}
|
||||
|
||||
def copy(
|
||||
self,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
base_url: str | httpx.URL | None = None,
|
||||
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
|
||||
http_client: httpx.Client | None = None,
|
||||
max_retries: int | NotGiven = NOT_GIVEN,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
set_default_headers: Mapping[str, str] | None = None,
|
||||
default_query: Mapping[str, object] | None = None,
|
||||
set_default_query: Mapping[str, object] | None = None,
|
||||
_extra_kwargs: Mapping[str, Any] = {},
|
||||
) -> Self:
|
||||
"""
|
||||
Create a new client instance re-using the same options given to the current client with optional overriding.
|
||||
"""
|
||||
if default_headers is not None and set_default_headers is not None:
|
||||
raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")
|
||||
|
||||
if default_query is not None and set_default_query is not None:
|
||||
raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive")
|
||||
|
||||
headers = self._custom_headers
|
||||
if default_headers is not None:
|
||||
headers = {**headers, **default_headers}
|
||||
elif set_default_headers is not None:
|
||||
headers = set_default_headers
|
||||
|
||||
params = self._custom_query
|
||||
if default_query is not None:
|
||||
params = {**params, **default_query}
|
||||
elif set_default_query is not None:
|
||||
params = set_default_query
|
||||
|
||||
http_client = http_client or self._client
|
||||
return self.__class__(
|
||||
api_key=api_key or self.api_key,
|
||||
base_url=base_url or self.base_url,
|
||||
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
|
||||
http_client=http_client,
|
||||
max_retries=max_retries if is_given(max_retries) else self.max_retries,
|
||||
default_headers=headers,
|
||||
default_query=params,
|
||||
**_extra_kwargs,
|
||||
)
|
||||
|
||||
# Alias for `copy` for nicer inline usage, e.g.
|
||||
# client.with_options(timeout=10).foo.create(...)
|
||||
with_options = copy
|
||||
|
||||
@override
|
||||
def _make_status_error(
|
||||
self,
|
||||
err_msg: str,
|
||||
*,
|
||||
body: object,
|
||||
response: httpx.Response,
|
||||
) -> APIStatusError:
|
||||
if response.status_code == 400:
|
||||
return _exceptions.BadRequestError(err_msg, response=response, body=body)
|
||||
|
||||
if response.status_code == 401:
|
||||
return _exceptions.AuthenticationError(err_msg, response=response, body=body)
|
||||
|
||||
if response.status_code == 403:
|
||||
return _exceptions.PermissionDeniedError(err_msg, response=response, body=body)
|
||||
|
||||
if response.status_code == 404:
|
||||
return _exceptions.NotFoundError(err_msg, response=response, body=body)
|
||||
|
||||
if response.status_code == 409:
|
||||
return _exceptions.ConflictError(err_msg, response=response, body=body)
|
||||
|
||||
if response.status_code == 422:
|
||||
return _exceptions.UnprocessableEntityError(err_msg, response=response, body=body)
|
||||
|
||||
if response.status_code == 429:
|
||||
return _exceptions.RateLimitError(err_msg, response=response, body=body)
|
||||
|
||||
if response.status_code >= 500:
|
||||
return _exceptions.InternalServerError(err_msg, response=response, body=body)
|
||||
return APIStatusError(err_msg, response=response, body=body)
|
||||
|
||||
|
||||
class AsyncTinker(AsyncAPIClient):
|
||||
# client options
|
||||
api_key: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
base_url: str | httpx.URL | None = None,
|
||||
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
|
||||
max_retries: int = DEFAULT_MAX_RETRIES,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
default_query: Mapping[str, object] | None = None,
|
||||
# Configure a custom httpx client.
|
||||
# We provide a `DefaultAsyncHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`.
|
||||
# See the [httpx documentation](https://www.python-httpx.org/api/#asyncclient) for more details.
|
||||
http_client: httpx.AsyncClient | None = None,
|
||||
# Enable or disable schema validation for data returned by the API.
|
||||
# When enabled an error APIResponseValidationError is raised
|
||||
# if the API responds with invalid data for the expected schema.
|
||||
#
|
||||
# This parameter may be removed or changed in the future.
|
||||
# If you rely on this feature, please open a GitHub issue
|
||||
# outlining your use-case to help us decide if it should be
|
||||
# part of our public interface in the future.
|
||||
_strict_response_validation: bool = False,
|
||||
) -> None:
|
||||
"""Construct a new async AsyncTinker client instance.
|
||||
|
||||
This automatically infers the `api_key` argument from the `TINKER_API_KEY` environment variable if it is not provided.
|
||||
"""
|
||||
if api_key is None:
|
||||
api_key = os.environ.get("TINKER_API_KEY")
|
||||
if api_key is None:
|
||||
raise TinkerError(
|
||||
"The api_key client option must be set either by passing api_key to the client or by setting the TINKER_API_KEY environment variable"
|
||||
)
|
||||
self.api_key = api_key
|
||||
|
||||
if base_url is None:
|
||||
base_url = os.environ.get("TINKER_BASE_URL")
|
||||
if base_url is None:
|
||||
base_url = f"https://tinker.thinkingmachines.dev/services/tinker-prod"
|
||||
|
||||
super().__init__(
|
||||
version=__version__,
|
||||
base_url=base_url,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout,
|
||||
http_client=http_client,
|
||||
custom_headers=default_headers,
|
||||
custom_query=default_query,
|
||||
_strict_response_validation=_strict_response_validation,
|
||||
)
|
||||
|
||||
self._idempotency_header = "X-Idempotency-Key"
|
||||
|
||||
@cached_property
|
||||
def service(self) -> AsyncServiceResource:
|
||||
from .resources.service import AsyncServiceResource
|
||||
|
||||
return AsyncServiceResource(self)
|
||||
|
||||
@cached_property
|
||||
def training(self) -> AsyncTrainingResource:
|
||||
from .resources.training import AsyncTrainingResource
|
||||
|
||||
return AsyncTrainingResource(self)
|
||||
|
||||
@cached_property
|
||||
def models(self) -> AsyncModelsResource:
|
||||
from .resources.models import AsyncModelsResource
|
||||
|
||||
return AsyncModelsResource(self)
|
||||
|
||||
@cached_property
|
||||
def weights(self) -> AsyncWeightsResource:
|
||||
from .resources.weights import AsyncWeightsResource
|
||||
|
||||
return AsyncWeightsResource(self)
|
||||
|
||||
@cached_property
|
||||
def sampling(self) -> AsyncSamplingResource:
|
||||
from .resources.sampling import AsyncSamplingResource
|
||||
|
||||
return AsyncSamplingResource(self)
|
||||
|
||||
@cached_property
|
||||
def futures(self) -> AsyncFuturesResource:
|
||||
from .resources.futures import AsyncFuturesResource
|
||||
|
||||
return AsyncFuturesResource(self)
|
||||
|
||||
@cached_property
|
||||
def telemetry(self) -> AsyncTelemetryResource:
|
||||
from .resources.telemetry import AsyncTelemetryResource
|
||||
|
||||
return AsyncTelemetryResource(self)
|
||||
|
||||
@cached_property
|
||||
def with_raw_response(self) -> AsyncTinkerWithRawResponse:
|
||||
return AsyncTinkerWithRawResponse(self)
|
||||
|
||||
@cached_property
|
||||
def with_streaming_response(self) -> AsyncTinkerWithStreamedResponse:
|
||||
return AsyncTinkerWithStreamedResponse(self)
|
||||
|
||||
@property
|
||||
@override
|
||||
def qs(self) -> Querystring:
|
||||
return Querystring(array_format="comma")
|
||||
|
||||
@property
|
||||
@override
|
||||
def auth_headers(self) -> dict[str, str]:
|
||||
api_key = self.api_key
|
||||
return {"X-API-Key": api_key}
|
||||
|
||||
@property
|
||||
@override
|
||||
def default_headers(self) -> dict[str, str | Omit]:
|
||||
return {
|
||||
**super().default_headers,
|
||||
"X-Stainless-Async": f"async:{get_async_library()}",
|
||||
**self._custom_headers,
|
||||
}
|
||||
|
||||
def copy(
|
||||
self,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
base_url: str | httpx.URL | None = None,
|
||||
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
|
||||
http_client: httpx.AsyncClient | None = None,
|
||||
max_retries: int | NotGiven = NOT_GIVEN,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
set_default_headers: Mapping[str, str] | None = None,
|
||||
default_query: Mapping[str, object] | None = None,
|
||||
set_default_query: Mapping[str, object] | None = None,
|
||||
_extra_kwargs: Mapping[str, Any] = {},
|
||||
) -> Self:
|
||||
"""
|
||||
Create a new client instance re-using the same options given to the current client with optional overriding.
|
||||
"""
|
||||
if default_headers is not None and set_default_headers is not None:
|
||||
raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")
|
||||
|
||||
if default_query is not None and set_default_query is not None:
|
||||
raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive")
|
||||
|
||||
headers = self._custom_headers
|
||||
if default_headers is not None:
|
||||
headers = {**headers, **default_headers}
|
||||
elif set_default_headers is not None:
|
||||
headers = set_default_headers
|
||||
|
||||
params = self._custom_query
|
||||
if default_query is not None:
|
||||
params = {**params, **default_query}
|
||||
elif set_default_query is not None:
|
||||
params = set_default_query
|
||||
|
||||
http_client = http_client or self._client
|
||||
return self.__class__(
|
||||
api_key=api_key or self.api_key,
|
||||
base_url=base_url or self.base_url,
|
||||
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
|
||||
http_client=http_client,
|
||||
max_retries=max_retries if is_given(max_retries) else self.max_retries,
|
||||
default_headers=headers,
|
||||
default_query=params,
|
||||
**_extra_kwargs,
|
||||
)
|
||||
|
||||
# Alias for `copy` for nicer inline usage, e.g.
|
||||
# client.with_options(timeout=10).foo.create(...)
|
||||
with_options = copy
|
||||
|
||||
@override
|
||||
def _make_status_error(
|
||||
self,
|
||||
err_msg: str,
|
||||
*,
|
||||
body: object,
|
||||
response: httpx.Response,
|
||||
) -> APIStatusError:
|
||||
if response.status_code == 400:
|
||||
return _exceptions.BadRequestError(err_msg, response=response, body=body)
|
||||
|
||||
if response.status_code == 401:
|
||||
return _exceptions.AuthenticationError(err_msg, response=response, body=body)
|
||||
|
||||
if response.status_code == 403:
|
||||
return _exceptions.PermissionDeniedError(err_msg, response=response, body=body)
|
||||
|
||||
if response.status_code == 404:
|
||||
return _exceptions.NotFoundError(err_msg, response=response, body=body)
|
||||
|
||||
if response.status_code == 409:
|
||||
return _exceptions.ConflictError(err_msg, response=response, body=body)
|
||||
|
||||
if response.status_code == 422:
|
||||
return _exceptions.UnprocessableEntityError(err_msg, response=response, body=body)
|
||||
|
||||
if response.status_code == 429:
|
||||
return _exceptions.RateLimitError(err_msg, response=response, body=body)
|
||||
|
||||
if response.status_code >= 500:
|
||||
return _exceptions.InternalServerError(err_msg, response=response, body=body)
|
||||
return APIStatusError(err_msg, response=response, body=body)
|
||||
|
||||
|
||||
class TinkerWithRawResponse:
|
||||
_client: Tinker
|
||||
|
||||
def __init__(self, client: Tinker) -> None:
|
||||
self._client = client
|
||||
|
||||
@cached_property
|
||||
def service(self) -> service.ServiceResourceWithRawResponse:
|
||||
from .resources.service import ServiceResourceWithRawResponse
|
||||
|
||||
return ServiceResourceWithRawResponse(self._client.service)
|
||||
|
||||
@cached_property
|
||||
def training(self) -> training.TrainingResourceWithRawResponse:
|
||||
from .resources.training import TrainingResourceWithRawResponse
|
||||
|
||||
return TrainingResourceWithRawResponse(self._client.training)
|
||||
|
||||
@cached_property
|
||||
def models(self) -> models.ModelsResourceWithRawResponse:
|
||||
from .resources.models import ModelsResourceWithRawResponse
|
||||
|
||||
return ModelsResourceWithRawResponse(self._client.models)
|
||||
|
||||
@cached_property
|
||||
def weights(self) -> weights.WeightsResourceWithRawResponse:
|
||||
from .resources.weights import WeightsResourceWithRawResponse
|
||||
|
||||
return WeightsResourceWithRawResponse(self._client.weights)
|
||||
|
||||
@cached_property
|
||||
def sampling(self) -> sampling.SamplingResourceWithRawResponse:
|
||||
from .resources.sampling import SamplingResourceWithRawResponse
|
||||
|
||||
return SamplingResourceWithRawResponse(self._client.sampling)
|
||||
|
||||
@cached_property
|
||||
def futures(self) -> futures.FuturesResourceWithRawResponse:
|
||||
from .resources.futures import FuturesResourceWithRawResponse
|
||||
|
||||
return FuturesResourceWithRawResponse(self._client.futures)
|
||||
|
||||
@cached_property
|
||||
def telemetry(self) -> telemetry.TelemetryResourceWithRawResponse:
|
||||
from .resources.telemetry import TelemetryResourceWithRawResponse
|
||||
|
||||
return TelemetryResourceWithRawResponse(self._client.telemetry)
|
||||
|
||||
|
||||
class AsyncTinkerWithRawResponse:
|
||||
_client: AsyncTinker
|
||||
|
||||
def __init__(self, client: AsyncTinker) -> None:
|
||||
self._client = client
|
||||
|
||||
@cached_property
|
||||
def service(self) -> service.AsyncServiceResourceWithRawResponse:
|
||||
from .resources.service import AsyncServiceResourceWithRawResponse
|
||||
|
||||
return AsyncServiceResourceWithRawResponse(self._client.service)
|
||||
|
||||
@cached_property
|
||||
def training(self) -> training.AsyncTrainingResourceWithRawResponse:
|
||||
from .resources.training import AsyncTrainingResourceWithRawResponse
|
||||
|
||||
return AsyncTrainingResourceWithRawResponse(self._client.training)
|
||||
|
||||
@cached_property
|
||||
def models(self) -> models.AsyncModelsResourceWithRawResponse:
|
||||
from .resources.models import AsyncModelsResourceWithRawResponse
|
||||
|
||||
return AsyncModelsResourceWithRawResponse(self._client.models)
|
||||
|
||||
@cached_property
|
||||
def weights(self) -> weights.AsyncWeightsResourceWithRawResponse:
|
||||
from .resources.weights import AsyncWeightsResourceWithRawResponse
|
||||
|
||||
return AsyncWeightsResourceWithRawResponse(self._client.weights)
|
||||
|
||||
@cached_property
|
||||
def sampling(self) -> sampling.AsyncSamplingResourceWithRawResponse:
|
||||
from .resources.sampling import AsyncSamplingResourceWithRawResponse
|
||||
|
||||
return AsyncSamplingResourceWithRawResponse(self._client.sampling)
|
||||
|
||||
@cached_property
|
||||
def futures(self) -> futures.AsyncFuturesResourceWithRawResponse:
|
||||
from .resources.futures import AsyncFuturesResourceWithRawResponse
|
||||
|
||||
return AsyncFuturesResourceWithRawResponse(self._client.futures)
|
||||
|
||||
@cached_property
|
||||
def telemetry(self) -> telemetry.AsyncTelemetryResourceWithRawResponse:
|
||||
from .resources.telemetry import AsyncTelemetryResourceWithRawResponse
|
||||
|
||||
return AsyncTelemetryResourceWithRawResponse(self._client.telemetry)
|
||||
|
||||
|
||||
class TinkerWithStreamedResponse:
|
||||
_client: Tinker
|
||||
|
||||
def __init__(self, client: Tinker) -> None:
|
||||
self._client = client
|
||||
|
||||
@cached_property
|
||||
def service(self) -> service.ServiceResourceWithStreamingResponse:
|
||||
from .resources.service import ServiceResourceWithStreamingResponse
|
||||
|
||||
return ServiceResourceWithStreamingResponse(self._client.service)
|
||||
|
||||
@cached_property
|
||||
def training(self) -> training.TrainingResourceWithStreamingResponse:
|
||||
from .resources.training import TrainingResourceWithStreamingResponse
|
||||
|
||||
return TrainingResourceWithStreamingResponse(self._client.training)
|
||||
|
||||
@cached_property
|
||||
def models(self) -> models.ModelsResourceWithStreamingResponse:
|
||||
from .resources.models import ModelsResourceWithStreamingResponse
|
||||
|
||||
return ModelsResourceWithStreamingResponse(self._client.models)
|
||||
|
||||
@cached_property
|
||||
def weights(self) -> weights.WeightsResourceWithStreamingResponse:
|
||||
from .resources.weights import WeightsResourceWithStreamingResponse
|
||||
|
||||
return WeightsResourceWithStreamingResponse(self._client.weights)
|
||||
|
||||
@cached_property
|
||||
def sampling(self) -> sampling.SamplingResourceWithStreamingResponse:
|
||||
from .resources.sampling import SamplingResourceWithStreamingResponse
|
||||
|
||||
return SamplingResourceWithStreamingResponse(self._client.sampling)
|
||||
|
||||
@cached_property
|
||||
def futures(self) -> futures.FuturesResourceWithStreamingResponse:
|
||||
from .resources.futures import FuturesResourceWithStreamingResponse
|
||||
|
||||
return FuturesResourceWithStreamingResponse(self._client.futures)
|
||||
|
||||
@cached_property
|
||||
def telemetry(self) -> telemetry.TelemetryResourceWithStreamingResponse:
|
||||
from .resources.telemetry import TelemetryResourceWithStreamingResponse
|
||||
|
||||
return TelemetryResourceWithStreamingResponse(self._client.telemetry)
|
||||
|
||||
|
||||
class AsyncTinkerWithStreamedResponse:
|
||||
_client: AsyncTinker
|
||||
|
||||
def __init__(self, client: AsyncTinker) -> None:
|
||||
self._client = client
|
||||
|
||||
@cached_property
|
||||
def service(self) -> service.AsyncServiceResourceWithStreamingResponse:
|
||||
from .resources.service import AsyncServiceResourceWithStreamingResponse
|
||||
|
||||
return AsyncServiceResourceWithStreamingResponse(self._client.service)
|
||||
|
||||
@cached_property
|
||||
def training(self) -> training.AsyncTrainingResourceWithStreamingResponse:
|
||||
from .resources.training import AsyncTrainingResourceWithStreamingResponse
|
||||
|
||||
return AsyncTrainingResourceWithStreamingResponse(self._client.training)
|
||||
|
||||
@cached_property
|
||||
def models(self) -> models.AsyncModelsResourceWithStreamingResponse:
|
||||
from .resources.models import AsyncModelsResourceWithStreamingResponse
|
||||
|
||||
return AsyncModelsResourceWithStreamingResponse(self._client.models)
|
||||
|
||||
@cached_property
|
||||
def weights(self) -> weights.AsyncWeightsResourceWithStreamingResponse:
|
||||
from .resources.weights import AsyncWeightsResourceWithStreamingResponse
|
||||
|
||||
return AsyncWeightsResourceWithStreamingResponse(self._client.weights)
|
||||
|
||||
@cached_property
|
||||
def sampling(self) -> sampling.AsyncSamplingResourceWithStreamingResponse:
|
||||
from .resources.sampling import AsyncSamplingResourceWithStreamingResponse
|
||||
|
||||
return AsyncSamplingResourceWithStreamingResponse(self._client.sampling)
|
||||
|
||||
@cached_property
|
||||
def futures(self) -> futures.AsyncFuturesResourceWithStreamingResponse:
|
||||
from .resources.futures import AsyncFuturesResourceWithStreamingResponse
|
||||
|
||||
return AsyncFuturesResourceWithStreamingResponse(self._client.futures)
|
||||
|
||||
@cached_property
|
||||
def telemetry(self) -> telemetry.AsyncTelemetryResourceWithStreamingResponse:
|
||||
from .resources.telemetry import AsyncTelemetryResourceWithStreamingResponse
|
||||
|
||||
return AsyncTelemetryResourceWithStreamingResponse(self._client.telemetry)
|
||||
|
||||
|
||||
Client = Tinker
|
||||
|
||||
AsyncClient = AsyncTinker
|
||||
219
src/tinker/_compat.py
Normal file
219
src/tinker/_compat.py
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload
|
||||
from datetime import date, datetime
|
||||
from typing_extensions import Self, Literal
|
||||
|
||||
import pydantic
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
from ._types import IncEx, StrBytesIntFloat
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel)
|
||||
|
||||
# --------------- Pydantic v2 compatibility ---------------
|
||||
|
||||
# Pyright incorrectly reports some of our functions as overriding a method when they don't
|
||||
# pyright: reportIncompatibleMethodOverride=false
|
||||
|
||||
PYDANTIC_V2 = pydantic.VERSION.startswith("2.")
|
||||
|
||||
# v1 re-exports
|
||||
if TYPE_CHECKING:
|
||||
|
||||
def parse_date(value: date | StrBytesIntFloat) -> date: # noqa: ARG001
|
||||
...
|
||||
|
||||
def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: # noqa: ARG001
|
||||
...
|
||||
|
||||
def get_args(t: type[Any]) -> tuple[Any, ...]: # noqa: ARG001
|
||||
...
|
||||
|
||||
def is_union(tp: type[Any] | None) -> bool: # noqa: ARG001
|
||||
...
|
||||
|
||||
def get_origin(t: type[Any]) -> type[Any] | None: # noqa: ARG001
|
||||
...
|
||||
|
||||
def is_literal_type(type_: type[Any]) -> bool: # noqa: ARG001
|
||||
...
|
||||
|
||||
def is_typeddict(type_: type[Any]) -> bool: # noqa: ARG001
|
||||
...
|
||||
|
||||
else:
|
||||
if PYDANTIC_V2:
|
||||
from pydantic.v1.typing import (
|
||||
get_args as get_args,
|
||||
is_union as is_union,
|
||||
get_origin as get_origin,
|
||||
is_typeddict as is_typeddict,
|
||||
is_literal_type as is_literal_type,
|
||||
)
|
||||
from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
|
||||
else:
|
||||
from pydantic.typing import (
|
||||
get_args as get_args,
|
||||
is_union as is_union,
|
||||
get_origin as get_origin,
|
||||
is_typeddict as is_typeddict,
|
||||
is_literal_type as is_literal_type,
|
||||
)
|
||||
from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
|
||||
|
||||
|
||||
# refactored config
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import ConfigDict as ConfigDict
|
||||
else:
|
||||
if PYDANTIC_V2:
|
||||
from pydantic import ConfigDict
|
||||
else:
|
||||
# TODO: provide an error message here?
|
||||
ConfigDict = None
|
||||
|
||||
|
||||
# renamed methods / properties
|
||||
def parse_obj(model: type[_ModelT], value: object) -> _ModelT:
|
||||
if PYDANTIC_V2:
|
||||
return model.model_validate(value)
|
||||
else:
|
||||
return cast(_ModelT, model.parse_obj(value)) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
|
||||
|
||||
|
||||
def field_is_required(field: FieldInfo) -> bool:
|
||||
if PYDANTIC_V2:
|
||||
return field.is_required()
|
||||
return field.required # type: ignore
|
||||
|
||||
|
||||
def field_get_default(field: FieldInfo) -> Any:
|
||||
value = field.get_default()
|
||||
if PYDANTIC_V2:
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
if value == PydanticUndefined:
|
||||
return None
|
||||
return value
|
||||
return value
|
||||
|
||||
|
||||
def field_outer_type(field: FieldInfo) -> Any:
|
||||
if PYDANTIC_V2:
|
||||
return field.annotation
|
||||
return field.outer_type_ # type: ignore
|
||||
|
||||
|
||||
def get_model_config(model: type[pydantic.BaseModel]) -> Any:
|
||||
if PYDANTIC_V2:
|
||||
return model.model_config
|
||||
return model.__config__ # type: ignore
|
||||
|
||||
|
||||
def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
|
||||
if PYDANTIC_V2:
|
||||
return model.model_fields
|
||||
return model.__fields__ # type: ignore
|
||||
|
||||
|
||||
def model_copy(model: _ModelT, *, deep: bool = False) -> _ModelT:
|
||||
if PYDANTIC_V2:
|
||||
return model.model_copy(deep=deep)
|
||||
return model.copy(deep=deep) # type: ignore
|
||||
|
||||
|
||||
def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
|
||||
if PYDANTIC_V2:
|
||||
return model.model_dump_json(indent=indent)
|
||||
return model.json(indent=indent) # type: ignore
|
||||
|
||||
|
||||
def model_dump(
|
||||
model: pydantic.BaseModel,
|
||||
*,
|
||||
exclude: IncEx | None = None,
|
||||
exclude_unset: bool = False,
|
||||
exclude_defaults: bool = False,
|
||||
warnings: bool = True,
|
||||
mode: Literal["json", "python"] = "python",
|
||||
) -> dict[str, Any]:
|
||||
if PYDANTIC_V2 or hasattr(model, "model_dump"):
|
||||
return model.model_dump(
|
||||
mode=mode,
|
||||
exclude=exclude,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
# warnings are not supported in Pydantic v1
|
||||
warnings=warnings if PYDANTIC_V2 else True,
|
||||
)
|
||||
return cast(
|
||||
"dict[str, Any]",
|
||||
model.dict( # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
|
||||
exclude=exclude,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def model_parse(model: type[_ModelT], data: Any) -> _ModelT:
|
||||
if PYDANTIC_V2:
|
||||
return model.model_validate(data)
|
||||
return model.parse_obj(data) # pyright: ignore[reportDeprecated]
|
||||
|
||||
|
||||
# generic models
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class GenericModel(pydantic.BaseModel): ...
|
||||
|
||||
else:
|
||||
if PYDANTIC_V2:
|
||||
# there no longer needs to be a distinction in v2 but
|
||||
# we still have to create our own subclass to avoid
|
||||
# inconsistent MRO ordering errors
|
||||
class GenericModel(pydantic.BaseModel): ...
|
||||
|
||||
else:
|
||||
import pydantic.generics
|
||||
|
||||
class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ...
|
||||
|
||||
|
||||
# cached properties
|
||||
if TYPE_CHECKING:
|
||||
cached_property = property
|
||||
|
||||
# we define a separate type (copied from typeshed)
|
||||
# that represents that `cached_property` is `set`able
|
||||
# at runtime, which differs from `@property`.
|
||||
#
|
||||
# this is a separate type as editors likely special case
|
||||
# `@property` and we don't want to cause issues just to have
|
||||
# more helpful internal types.
|
||||
|
||||
class typed_cached_property(Generic[_T]):
|
||||
func: Callable[[Any], _T]
|
||||
attrname: str | None
|
||||
|
||||
def __init__(self, func: Callable[[Any], _T]) -> None: ...
|
||||
|
||||
@overload
|
||||
def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: ...
|
||||
|
||||
@overload
|
||||
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: ...
|
||||
|
||||
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self:
|
||||
raise NotImplementedError()
|
||||
|
||||
def __set_name__(self, owner: type[Any], name: str) -> None: ...
|
||||
|
||||
# __set__ is not defined at runtime, but @cached_property is designed to be settable
|
||||
def __set__(self, instance: object, value: _T) -> None: ...
|
||||
else:
|
||||
from functools import cached_property as cached_property
|
||||
|
||||
typed_cached_property = cached_property
|
||||
14
src/tinker/_constants.py
Normal file
14
src/tinker/_constants.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
import httpx
|
||||
|
||||
RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response"
|
||||
OVERRIDE_CAST_TO_HEADER = "____stainless_override_cast_to"
|
||||
|
||||
# default timeout is 1 minute
|
||||
DEFAULT_TIMEOUT = httpx.Timeout(timeout=60, connect=5.0)
|
||||
DEFAULT_MAX_RETRIES = 10
|
||||
DEFAULT_CONNECTION_LIMITS = httpx.Limits(max_connections=1000, max_keepalive_connections=20)
|
||||
|
||||
INITIAL_RETRY_DELAY = 0.5
|
||||
MAX_RETRY_DELAY = 10.0
|
||||
108
src/tinker/_exceptions.py
Normal file
108
src/tinker/_exceptions.py
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
import httpx
|
||||
|
||||
__all__ = [
|
||||
"BadRequestError",
|
||||
"AuthenticationError",
|
||||
"PermissionDeniedError",
|
||||
"NotFoundError",
|
||||
"ConflictError",
|
||||
"UnprocessableEntityError",
|
||||
"RateLimitError",
|
||||
"InternalServerError",
|
||||
]
|
||||
|
||||
|
||||
class TinkerError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class APIError(TinkerError):
|
||||
message: str
|
||||
request: httpx.Request
|
||||
|
||||
body: object | None
|
||||
"""The API response body.
|
||||
|
||||
If the API responded with a valid JSON structure then this property will be the
|
||||
decoded result.
|
||||
|
||||
If it isn't a valid JSON structure then this will be the raw response.
|
||||
|
||||
If there was no response associated with this error then it will be `None`.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str, request: httpx.Request, *, body: object | None) -> None: # noqa: ARG002
|
||||
super().__init__(message)
|
||||
self.request = request
|
||||
self.message = message
|
||||
self.body = body
|
||||
|
||||
|
||||
class APIResponseValidationError(APIError):
|
||||
response: httpx.Response
|
||||
status_code: int
|
||||
|
||||
def __init__(self, response: httpx.Response, body: object | None, *, message: str | None = None) -> None:
|
||||
super().__init__(message or "Data returned by API invalid for expected schema.", response.request, body=body)
|
||||
self.response = response
|
||||
self.status_code = response.status_code
|
||||
|
||||
|
||||
class APIStatusError(APIError):
|
||||
"""Raised when an API response has a status code of 4xx or 5xx."""
|
||||
|
||||
response: httpx.Response
|
||||
status_code: int
|
||||
|
||||
def __init__(self, message: str, *, response: httpx.Response, body: object | None) -> None:
|
||||
super().__init__(message, response.request, body=body)
|
||||
self.response = response
|
||||
self.status_code = response.status_code
|
||||
|
||||
|
||||
class APIConnectionError(APIError):
|
||||
def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None:
|
||||
super().__init__(message, request, body=None)
|
||||
|
||||
|
||||
class APITimeoutError(APIConnectionError):
|
||||
def __init__(self, request: httpx.Request) -> None:
|
||||
super().__init__(message="Request timed out.", request=request)
|
||||
|
||||
|
||||
class BadRequestError(APIStatusError):
|
||||
status_code: Literal[400] = 400 # pyright: ignore[reportIncompatibleVariableOverride]
|
||||
|
||||
|
||||
class AuthenticationError(APIStatusError):
|
||||
status_code: Literal[401] = 401 # pyright: ignore[reportIncompatibleVariableOverride]
|
||||
|
||||
|
||||
class PermissionDeniedError(APIStatusError):
|
||||
status_code: Literal[403] = 403 # pyright: ignore[reportIncompatibleVariableOverride]
|
||||
|
||||
|
||||
class NotFoundError(APIStatusError):
|
||||
status_code: Literal[404] = 404 # pyright: ignore[reportIncompatibleVariableOverride]
|
||||
|
||||
|
||||
class ConflictError(APIStatusError):
|
||||
status_code: Literal[409] = 409 # pyright: ignore[reportIncompatibleVariableOverride]
|
||||
|
||||
|
||||
class UnprocessableEntityError(APIStatusError):
|
||||
status_code: Literal[422] = 422 # pyright: ignore[reportIncompatibleVariableOverride]
|
||||
|
||||
|
||||
class RateLimitError(APIStatusError):
|
||||
status_code: Literal[429] = 429 # pyright: ignore[reportIncompatibleVariableOverride]
|
||||
|
||||
|
||||
class InternalServerError(APIStatusError):
|
||||
pass
|
||||
123
src/tinker/_files.py
Normal file
123
src/tinker/_files.py
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
import pathlib
|
||||
from typing import overload
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
import anyio
|
||||
|
||||
from ._types import (
|
||||
FileTypes,
|
||||
FileContent,
|
||||
RequestFiles,
|
||||
HttpxFileTypes,
|
||||
Base64FileInput,
|
||||
HttpxFileContent,
|
||||
HttpxRequestFiles,
|
||||
)
|
||||
from ._utils import is_tuple_t, is_mapping_t, is_sequence_t
|
||||
|
||||
|
||||
def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:
|
||||
return isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
|
||||
|
||||
|
||||
def is_file_content(obj: object) -> TypeGuard[FileContent]:
|
||||
return (
|
||||
isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
|
||||
)
|
||||
|
||||
|
||||
def assert_is_file_content(obj: object, *, key: str | None = None) -> None:
|
||||
if not is_file_content(obj):
|
||||
prefix = f"Expected entry at `{key}`" if key is not None else f"Expected file input `{obj!r}`"
|
||||
raise RuntimeError(
|
||||
f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead."
|
||||
) from None
|
||||
|
||||
|
||||
@overload
|
||||
def to_httpx_files(files: None) -> None: ...
|
||||
|
||||
|
||||
@overload
|
||||
def to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: ...
|
||||
|
||||
|
||||
def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
|
||||
if files is None:
|
||||
return None
|
||||
|
||||
if is_mapping_t(files):
|
||||
files = {key: _transform_file(file) for key, file in files.items()}
|
||||
elif is_sequence_t(files):
|
||||
files = [(key, _transform_file(file)) for key, file in files]
|
||||
else:
|
||||
raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence")
|
||||
|
||||
return files
|
||||
|
||||
|
||||
def _transform_file(file: FileTypes) -> HttpxFileTypes:
|
||||
if is_file_content(file):
|
||||
if isinstance(file, os.PathLike):
|
||||
path = pathlib.Path(file)
|
||||
return (path.name, path.read_bytes())
|
||||
|
||||
return file
|
||||
|
||||
if is_tuple_t(file):
|
||||
return (file[0], read_file_content(file[1]), *file[2:])
|
||||
|
||||
raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
|
||||
|
||||
|
||||
def read_file_content(file: FileContent) -> HttpxFileContent:
|
||||
if isinstance(file, os.PathLike):
|
||||
return pathlib.Path(file).read_bytes()
|
||||
return file
|
||||
|
||||
|
||||
@overload
|
||||
async def async_to_httpx_files(files: None) -> None: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def async_to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: ...
|
||||
|
||||
|
||||
async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
|
||||
if files is None:
|
||||
return None
|
||||
|
||||
if is_mapping_t(files):
|
||||
files = {key: await _async_transform_file(file) for key, file in files.items()}
|
||||
elif is_sequence_t(files):
|
||||
files = [(key, await _async_transform_file(file)) for key, file in files]
|
||||
else:
|
||||
raise TypeError("Unexpected file type input {type(files)}, expected mapping or sequence")
|
||||
|
||||
return files
|
||||
|
||||
|
||||
async def _async_transform_file(file: FileTypes) -> HttpxFileTypes:
|
||||
if is_file_content(file):
|
||||
if isinstance(file, os.PathLike):
|
||||
path = anyio.Path(file)
|
||||
return (path.name, await path.read_bytes())
|
||||
|
||||
return file
|
||||
|
||||
if is_tuple_t(file):
|
||||
return (file[0], await async_read_file_content(file[1]), *file[2:])
|
||||
|
||||
raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple")
|
||||
|
||||
|
||||
async def async_read_file_content(file: FileContent) -> HttpxFileContent:
|
||||
if isinstance(file, os.PathLike):
|
||||
return await anyio.Path(file).read_bytes()
|
||||
|
||||
return file
|
||||
560
src/tinker/_models.py
Normal file
560
src/tinker/_models.py
Normal file
|
|
@ -0,0 +1,560 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, Optional, cast
|
||||
from datetime import date, datetime
|
||||
from typing import TYPE_CHECKING, Any, Callable, Generic, Type, TypeVar, Union, cast
|
||||
|
||||
import pydantic
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import (
|
||||
List,
|
||||
Unpack,
|
||||
Literal,
|
||||
ClassVar,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
Required,
|
||||
TypedDict,
|
||||
TypeGuard,
|
||||
Unpack,
|
||||
final,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from ._compat import (
|
||||
PYDANTIC_V2,
|
||||
ConfigDict,
|
||||
field_get_default,
|
||||
get_args,
|
||||
get_origin,
|
||||
is_literal_type,
|
||||
is_union,
|
||||
parse_obj,
|
||||
)
|
||||
from ._compat import (
|
||||
GenericModel as BaseGenericModel,
|
||||
)
|
||||
from ._constants import RAW_RESPONSE_HEADER
|
||||
from ._types import (
|
||||
AnyMapping,
|
||||
Body,
|
||||
Headers,
|
||||
HttpxRequestFiles,
|
||||
NotGiven,
|
||||
Query,
|
||||
Timeout,
|
||||
)
|
||||
from ._utils import (
|
||||
PropertyInfo,
|
||||
extract_type_arg,
|
||||
is_annotated_type,
|
||||
is_given,
|
||||
is_list,
|
||||
is_mapping,
|
||||
is_type_alias_type,
|
||||
lru_cache,
|
||||
parse_date,
|
||||
parse_datetime,
|
||||
strip_annotated_type,
|
||||
strip_not_given,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_core.core_schema import LiteralSchema, ModelField, ModelFieldsSchema, ModelSchema
|
||||
|
||||
__all__ = ["StrictBase", "BaseModel", "GenericModel"]
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_BaseModelT = TypeVar("_BaseModelT", bound=pydantic.BaseModel)
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class _ConfigProtocol(Protocol):
|
||||
allow_population_by_field_name: bool
|
||||
|
||||
|
||||
class StrictBase(pydantic.BaseModel):
|
||||
"""
|
||||
Don't allow extra fields, so user errors are caught earlier.
|
||||
Use this for request types.
|
||||
"""
|
||||
model_config = ConfigDict(frozen=True, extra="forbid")
|
||||
|
||||
def __str__(self) -> str:
|
||||
return repr(self)
|
||||
|
||||
class BaseModel(pydantic.BaseModel):
|
||||
"""
|
||||
Use for classes that may appear in responses. Allow extra fields, so old clients can still work.
|
||||
"""
|
||||
|
||||
# For future-proofing, we ignore extra fields in case the server adds new fields.
|
||||
model_config = ConfigDict(frozen=True, extra="ignore")
|
||||
|
||||
def __str__(self) -> str:
|
||||
return repr(self)
|
||||
|
||||
|
||||
def _construct_field(value: object, field: FieldInfo, key: str) -> object:
|
||||
if value is None:
|
||||
return field_get_default(field)
|
||||
|
||||
if PYDANTIC_V2:
|
||||
type_ = field.annotation
|
||||
else:
|
||||
type_ = cast(type, field.outer_type_) # type: ignore
|
||||
|
||||
if type_ is None:
|
||||
raise RuntimeError(f"Unexpected field type is None for {key}")
|
||||
|
||||
return construct_type(value=value, type_=type_, metadata=getattr(field, "metadata", None))
|
||||
|
||||
|
||||
def is_basemodel(type_: type) -> bool:
|
||||
"""Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`"""
|
||||
if is_union(type_):
|
||||
for variant in get_args(type_):
|
||||
if is_basemodel(variant):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
return is_basemodel_type(type_)
|
||||
|
||||
|
||||
def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]:
|
||||
origin = get_origin(type_) or type_
|
||||
if not inspect.isclass(origin):
|
||||
return False
|
||||
return issubclass(origin, BaseModel) or issubclass(origin, GenericModel)
|
||||
|
||||
|
||||
def build(
|
||||
base_model_cls: Callable[P, _BaseModelT],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> _BaseModelT:
|
||||
"""Construct a BaseModel class without validation.
|
||||
|
||||
This is useful for cases where you need to instantiate a `BaseModel`
|
||||
from an API response as this provides type-safe params which isn't supported
|
||||
by helpers like `construct_type()`.
|
||||
|
||||
```py
|
||||
build(MyModel, my_field_a="foo", my_field_b=123)
|
||||
```
|
||||
"""
|
||||
if args:
|
||||
raise TypeError(
|
||||
"Received positional arguments which are not supported; Keyword arguments must be used instead",
|
||||
)
|
||||
|
||||
return cast(_BaseModelT, construct_type(type_=base_model_cls, value=kwargs))
|
||||
|
||||
|
||||
def construct_type_unchecked(*, value: object, type_: type[_T]) -> _T:
|
||||
"""Loose coercion to the expected type with construction of nested values.
|
||||
|
||||
Note: the returned value from this function is not guaranteed to match the
|
||||
given type.
|
||||
"""
|
||||
return cast(_T, construct_type(value=value, type_=type_))
|
||||
|
||||
|
||||
def construct_type(*, value: object, type_: object, metadata: Optional[List[Any]] = None) -> object:
|
||||
"""Loose coercion to the expected type with construction of nested values.
|
||||
|
||||
If the given value does not match the expected type then it is returned as-is.
|
||||
"""
|
||||
|
||||
# store a reference to the original type we were given before we extract any inner
|
||||
# types so that we can properly resolve forward references in `TypeAliasType` annotations
|
||||
original_type = None
|
||||
|
||||
# we allow `object` as the input type because otherwise, passing things like
|
||||
# `Literal['value']` will be reported as a type error by type checkers
|
||||
type_ = cast("type[object]", type_)
|
||||
if is_type_alias_type(type_):
|
||||
original_type = type_ # type: ignore[unreachable]
|
||||
type_ = type_.__value__ # type: ignore[unreachable]
|
||||
|
||||
# unwrap `Annotated[T, ...]` -> `T`
|
||||
if metadata is not None and len(metadata) > 0:
|
||||
meta: tuple[Any, ...] = tuple(metadata)
|
||||
elif is_annotated_type(type_):
|
||||
meta = get_args(type_)[1:]
|
||||
type_ = extract_type_arg(type_, 0)
|
||||
else:
|
||||
meta = tuple()
|
||||
|
||||
# we need to use the origin class for any types that are subscripted generics
|
||||
# e.g. Dict[str, object]
|
||||
origin = get_origin(type_) or type_
|
||||
args = get_args(type_)
|
||||
|
||||
if is_union(origin):
|
||||
try:
|
||||
return validate_type(type_=cast("type[object]", original_type or type_), value=value)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# if the type is a discriminated union then we want to construct the right variant
|
||||
# in the union, even if the data doesn't match exactly, otherwise we'd break code
|
||||
# that relies on the constructed class types, e.g.
|
||||
#
|
||||
# class FooType:
|
||||
# kind: Literal['foo']
|
||||
# value: str
|
||||
#
|
||||
# class BarType:
|
||||
# kind: Literal['bar']
|
||||
# value: int
|
||||
#
|
||||
# without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
|
||||
# we'd end up constructing `FooType` when it should be `BarType`.
|
||||
discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
|
||||
if discriminator and is_mapping(value):
|
||||
variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
|
||||
if variant_value and isinstance(variant_value, str):
|
||||
variant_type = discriminator.mapping.get(variant_value)
|
||||
if variant_type:
|
||||
return construct_type(type_=variant_type, value=value)
|
||||
|
||||
# if the data is not valid, use the first variant that doesn't fail while deserializing
|
||||
for variant in args:
|
||||
try:
|
||||
return construct_type(value=value, type_=variant)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
raise RuntimeError(f"Could not convert data into a valid instance of {type_}")
|
||||
|
||||
if origin == dict:
|
||||
if not is_mapping(value):
|
||||
return value
|
||||
|
||||
_, items_type = get_args(type_) # Dict[_, items_type]
|
||||
return {key: construct_type(value=item, type_=items_type) for key, item in value.items()}
|
||||
|
||||
if (
|
||||
not is_literal_type(type_)
|
||||
and inspect.isclass(origin)
|
||||
and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel))
|
||||
):
|
||||
if is_list(value):
|
||||
return [
|
||||
cast(Any, type_).construct(**entry) if is_mapping(entry) else entry
|
||||
for entry in value
|
||||
]
|
||||
|
||||
if is_mapping(value):
|
||||
if issubclass(type_, BaseModel):
|
||||
return type_.construct(**value) # type: ignore[arg-type]
|
||||
|
||||
return cast(Any, type_).construct(**value)
|
||||
|
||||
if origin == list:
|
||||
if not is_list(value):
|
||||
return value
|
||||
|
||||
inner_type = args[0] # List[inner_type]
|
||||
return [construct_type(value=entry, type_=inner_type) for entry in value]
|
||||
|
||||
if origin == float:
|
||||
if isinstance(value, int):
|
||||
coerced = float(value)
|
||||
if coerced != value:
|
||||
return value
|
||||
return coerced
|
||||
|
||||
return value
|
||||
|
||||
if type_ == datetime:
|
||||
try:
|
||||
return parse_datetime(value) # type: ignore
|
||||
except Exception:
|
||||
return value
|
||||
|
||||
if type_ == date:
|
||||
try:
|
||||
return parse_date(value) # type: ignore
|
||||
except Exception:
|
||||
return value
|
||||
|
||||
return value
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class CachedDiscriminatorType(Protocol):
|
||||
__discriminator__: DiscriminatorDetails
|
||||
|
||||
|
||||
class DiscriminatorDetails:
|
||||
field_name: str
|
||||
"""The name of the discriminator field in the variant class, e.g.
|
||||
|
||||
```py
|
||||
class Foo(BaseModel):
|
||||
type: Literal['foo']
|
||||
```
|
||||
|
||||
Will result in field_name='type'
|
||||
"""
|
||||
|
||||
field_alias_from: str | None
|
||||
"""The name of the discriminator field in the API response, e.g.
|
||||
|
||||
```py
|
||||
class Foo(BaseModel):
|
||||
type: Literal['foo'] = Field(alias='type_from_api')
|
||||
```
|
||||
|
||||
Will result in field_alias_from='type_from_api'
|
||||
"""
|
||||
|
||||
mapping: dict[str, type]
|
||||
"""Mapping of discriminator value to variant type, e.g.
|
||||
|
||||
{'foo': FooVariant, 'bar': BarVariant}
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
mapping: dict[str, type],
|
||||
discriminator_field: str,
|
||||
discriminator_alias: str | None,
|
||||
) -> None:
|
||||
self.mapping = mapping
|
||||
self.field_name = discriminator_field
|
||||
self.field_alias_from = discriminator_alias
|
||||
|
||||
|
||||
def _build_discriminated_union_meta(
|
||||
*, union: type, meta_annotations: tuple[Any, ...]
|
||||
) -> DiscriminatorDetails | None:
|
||||
if isinstance(union, CachedDiscriminatorType):
|
||||
return union.__discriminator__
|
||||
|
||||
discriminator_field_name: str | None = None
|
||||
|
||||
for annotation in meta_annotations:
|
||||
if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None:
|
||||
discriminator_field_name = annotation.discriminator
|
||||
break
|
||||
|
||||
if not discriminator_field_name:
|
||||
return None
|
||||
|
||||
mapping: dict[str, type] = {}
|
||||
discriminator_alias: str | None = None
|
||||
|
||||
for variant in get_args(union):
|
||||
variant = strip_annotated_type(variant)
|
||||
if is_basemodel_type(variant):
|
||||
if PYDANTIC_V2:
|
||||
field = _extract_field_schema_pv2(variant, discriminator_field_name)
|
||||
if not field:
|
||||
continue
|
||||
|
||||
# Note: if one variant defines an alias then they all should
|
||||
discriminator_alias = field.get("serialization_alias")
|
||||
|
||||
field_schema = field["schema"]
|
||||
|
||||
if field_schema["type"] == "literal":
|
||||
for entry in cast("LiteralSchema", field_schema)["expected"]:
|
||||
if isinstance(entry, str):
|
||||
mapping[entry] = variant
|
||||
else:
|
||||
field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(
|
||||
discriminator_field_name
|
||||
) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
|
||||
if not field_info:
|
||||
continue
|
||||
|
||||
# Note: if one variant defines an alias then they all should
|
||||
discriminator_alias = field_info.alias
|
||||
|
||||
if (annotation := getattr(field_info, "annotation", None)) and is_literal_type(
|
||||
annotation
|
||||
):
|
||||
for entry in get_args(annotation):
|
||||
if isinstance(entry, str):
|
||||
mapping[entry] = variant
|
||||
|
||||
if not mapping:
|
||||
return None
|
||||
|
||||
details = DiscriminatorDetails(
|
||||
mapping=mapping,
|
||||
discriminator_field=discriminator_field_name,
|
||||
discriminator_alias=discriminator_alias,
|
||||
)
|
||||
cast(CachedDiscriminatorType, union).__discriminator__ = details
|
||||
return details
|
||||
|
||||
|
||||
def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
|
||||
schema = model.__pydantic_core_schema__
|
||||
if schema["type"] == "definitions":
|
||||
schema = schema["schema"]
|
||||
|
||||
if schema["type"] != "model":
|
||||
return None
|
||||
|
||||
schema = cast("ModelSchema", schema)
|
||||
fields_schema = schema["schema"]
|
||||
if fields_schema["type"] != "model-fields":
|
||||
return None
|
||||
|
||||
fields_schema = cast("ModelFieldsSchema", fields_schema)
|
||||
field = fields_schema["fields"].get(field_name)
|
||||
if not field:
|
||||
return None
|
||||
|
||||
return cast("ModelField", field) # pyright: ignore[reportUnnecessaryCast]
|
||||
|
||||
|
||||
def validate_type(*, type_: type[_T], value: object) -> _T:
|
||||
"""Strict validation that the given value matches the expected type"""
|
||||
if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
|
||||
return cast(_T, parse_obj(type_, value))
|
||||
|
||||
return cast(_T, _validate_non_model_type(type_=type_, value=value))
|
||||
|
||||
|
||||
def set_pydantic_config(typ: Any, config: pydantic.ConfigDict) -> None:
|
||||
"""Add a pydantic config for the given type.
|
||||
|
||||
Note: this is a no-op on Pydantic v1.
|
||||
"""
|
||||
setattr(typ, "__pydantic_config__", config) # noqa: B010
|
||||
|
||||
|
||||
# our use of subclassing here causes weirdness for type checkers,
|
||||
# so we just pretend that we don't subclass
|
||||
if TYPE_CHECKING:
|
||||
GenericModel = BaseModel
|
||||
else:
|
||||
|
||||
class GenericModel(BaseGenericModel, BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
if PYDANTIC_V2:
|
||||
from pydantic import TypeAdapter as _TypeAdapter
|
||||
|
||||
_CachedTypeAdapter = cast("TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter))
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import TypeAdapter
|
||||
else:
|
||||
TypeAdapter = _CachedTypeAdapter
|
||||
|
||||
def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
|
||||
return TypeAdapter(type_).validate_python(value)
|
||||
|
||||
elif not TYPE_CHECKING: # TODO: condition is weird
|
||||
|
||||
class RootModel(GenericModel, Generic[_T]):
|
||||
"""Used as a placeholder to easily convert runtime types to a Pydantic format
|
||||
to provide validation.
|
||||
|
||||
For example:
|
||||
```py
|
||||
validated = RootModel[int](__root__="5").__root__
|
||||
# validated: 5
|
||||
```
|
||||
"""
|
||||
|
||||
__root__: _T
|
||||
|
||||
def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
|
||||
model = _create_pydantic_model(type_).validate(value)
|
||||
return cast(_T, model.__root__)
|
||||
|
||||
def _create_pydantic_model(type_: _T) -> Type[RootModel[_T]]:
|
||||
return RootModel[type_] # type: ignore
|
||||
|
||||
|
||||
class FinalRequestOptionsInput(TypedDict, total=False):
|
||||
method: Required[str]
|
||||
url: Required[str]
|
||||
params: Query
|
||||
headers: Headers
|
||||
max_retries: int
|
||||
timeout: float | Timeout | None
|
||||
files: HttpxRequestFiles | None
|
||||
idempotency_key: str
|
||||
json_data: Body
|
||||
extra_json: AnyMapping
|
||||
follow_redirects: bool
|
||||
|
||||
|
||||
@final
|
||||
class FinalRequestOptions(pydantic.BaseModel):
|
||||
method: str
|
||||
url: str
|
||||
params: Query = {}
|
||||
headers: Union[Headers, NotGiven] = NotGiven()
|
||||
max_retries: Union[int, NotGiven] = NotGiven()
|
||||
timeout: Union[float, Timeout, None, NotGiven] = NotGiven()
|
||||
files: Union[HttpxRequestFiles, None] = None
|
||||
idempotency_key: Union[str, None] = None
|
||||
post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven()
|
||||
follow_redirects: Union[bool, None] = None
|
||||
|
||||
# It should be noted that we cannot use `json` here as that would override
|
||||
# a BaseModel method in an incompatible fashion.
|
||||
json_data: Union[Body, None] = None
|
||||
extra_json: Union[AnyMapping, None] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
|
||||
else:
|
||||
|
||||
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
|
||||
arbitrary_types_allowed: bool = True
|
||||
|
||||
def get_max_retries(self, max_retries: int) -> int:
|
||||
if isinstance(self.max_retries, NotGiven):
|
||||
return max_retries
|
||||
return self.max_retries
|
||||
|
||||
def _strip_raw_response_header(self) -> None:
|
||||
if not is_given(self.headers):
|
||||
return
|
||||
|
||||
if self.headers.get(RAW_RESPONSE_HEADER):
|
||||
self.headers = {**self.headers}
|
||||
self.headers.pop(RAW_RESPONSE_HEADER)
|
||||
|
||||
# override the `construct` method so that we can run custom transformations.
|
||||
# this is necessary as we don't want to do any actual runtime type checking
|
||||
# (which means we can't use validators) but we do want to ensure that `NotGiven`
|
||||
# values are not present
|
||||
#
|
||||
# type ignore required because we're adding explicit types to `**values`
|
||||
@classmethod
|
||||
def construct( # type: ignore
|
||||
cls,
|
||||
_fields_set: set[str] | None = None,
|
||||
**values: Unpack[FinalRequestOptionsInput],
|
||||
) -> FinalRequestOptions:
|
||||
kwargs: dict[str, Any] = {
|
||||
# we unconditionally call `strip_not_given` on any value
|
||||
# as it will just ignore any non-mapping types
|
||||
key: strip_not_given(value)
|
||||
for key, value in values.items()
|
||||
}
|
||||
if PYDANTIC_V2:
|
||||
return super().model_construct(_fields_set, **kwargs)
|
||||
return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated]
|
||||
|
||||
if not TYPE_CHECKING:
|
||||
# type checkers incorrectly complain about this assignment
|
||||
model_construct = construct
|
||||
150
src/tinker/_qs.py
Normal file
150
src/tinker/_qs.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List, Tuple, Union, Mapping, TypeVar
|
||||
from urllib.parse import parse_qs, urlencode
|
||||
from typing_extensions import Literal, get_args
|
||||
|
||||
from ._types import NOT_GIVEN, NotGiven, NotGivenOr
|
||||
from ._utils import flatten
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
ArrayFormat = Literal["comma", "repeat", "indices", "brackets"]
|
||||
NestedFormat = Literal["dots", "brackets"]
|
||||
|
||||
PrimitiveData = Union[str, int, float, bool, None]
|
||||
# this should be Data = Union[PrimitiveData, "List[Data]", "Tuple[Data]", "Mapping[str, Data]"]
|
||||
# https://github.com/microsoft/pyright/issues/3555
|
||||
Data = Union[PrimitiveData, List[Any], Tuple[Any], "Mapping[str, Any]"]
|
||||
Params = Mapping[str, Data]
|
||||
|
||||
|
||||
class Querystring:
|
||||
array_format: ArrayFormat
|
||||
nested_format: NestedFormat
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
array_format: ArrayFormat = "repeat",
|
||||
nested_format: NestedFormat = "brackets",
|
||||
) -> None:
|
||||
self.array_format = array_format
|
||||
self.nested_format = nested_format
|
||||
|
||||
def parse(self, query: str) -> Mapping[str, object]:
|
||||
# Note: custom format syntax is not supported yet
|
||||
return parse_qs(query)
|
||||
|
||||
def stringify(
|
||||
self,
|
||||
params: Params,
|
||||
*,
|
||||
array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN,
|
||||
nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN,
|
||||
) -> str:
|
||||
return urlencode(
|
||||
self.stringify_items(
|
||||
params,
|
||||
array_format=array_format,
|
||||
nested_format=nested_format,
|
||||
)
|
||||
)
|
||||
|
||||
def stringify_items(
|
||||
self,
|
||||
params: Params,
|
||||
*,
|
||||
array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN,
|
||||
nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN,
|
||||
) -> list[tuple[str, str]]:
|
||||
opts = Options(
|
||||
qs=self,
|
||||
array_format=array_format,
|
||||
nested_format=nested_format,
|
||||
)
|
||||
return flatten([self._stringify_item(key, value, opts) for key, value in params.items()])
|
||||
|
||||
def _stringify_item(
|
||||
self,
|
||||
key: str,
|
||||
value: Data,
|
||||
opts: Options,
|
||||
) -> list[tuple[str, str]]:
|
||||
if isinstance(value, Mapping):
|
||||
items: list[tuple[str, str]] = []
|
||||
nested_format = opts.nested_format
|
||||
for subkey, subvalue in value.items():
|
||||
items.extend(
|
||||
self._stringify_item(
|
||||
# TODO: error if unknown format
|
||||
f"{key}.{subkey}" if nested_format == "dots" else f"{key}[{subkey}]",
|
||||
subvalue,
|
||||
opts,
|
||||
)
|
||||
)
|
||||
return items
|
||||
|
||||
if isinstance(value, (list, tuple)):
|
||||
array_format = opts.array_format
|
||||
if array_format == "comma":
|
||||
return [
|
||||
(
|
||||
key,
|
||||
",".join(self._primitive_value_to_str(item) for item in value if item is not None),
|
||||
),
|
||||
]
|
||||
elif array_format == "repeat":
|
||||
items = []
|
||||
for item in value:
|
||||
items.extend(self._stringify_item(key, item, opts))
|
||||
return items
|
||||
elif array_format == "indices":
|
||||
raise NotImplementedError("The array indices format is not supported yet")
|
||||
elif array_format == "brackets":
|
||||
items = []
|
||||
key = key + "[]"
|
||||
for item in value:
|
||||
items.extend(self._stringify_item(key, item, opts))
|
||||
return items
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unknown array_format value: {array_format}, choose from {', '.join(get_args(ArrayFormat))}"
|
||||
)
|
||||
|
||||
serialised = self._primitive_value_to_str(value)
|
||||
if not serialised:
|
||||
return []
|
||||
return [(key, serialised)]
|
||||
|
||||
def _primitive_value_to_str(self, value: PrimitiveData) -> str:
|
||||
# copied from httpx
|
||||
if value is True:
|
||||
return "true"
|
||||
elif value is False:
|
||||
return "false"
|
||||
elif value is None:
|
||||
return ""
|
||||
return str(value)
|
||||
|
||||
|
||||
_qs = Querystring()
|
||||
parse = _qs.parse
|
||||
stringify = _qs.stringify
|
||||
stringify_items = _qs.stringify_items
|
||||
|
||||
|
||||
class Options:
|
||||
array_format: ArrayFormat
|
||||
nested_format: NestedFormat
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
qs: Querystring = _qs,
|
||||
*,
|
||||
array_format: NotGivenOr[ArrayFormat] = NOT_GIVEN,
|
||||
nested_format: NotGivenOr[NestedFormat] = NOT_GIVEN,
|
||||
) -> None:
|
||||
self.array_format = qs.array_format if isinstance(array_format, NotGiven) else array_format
|
||||
self.nested_format = qs.nested_format if isinstance(nested_format, NotGiven) else nested_format
|
||||
43
src/tinker/_resource.py
Normal file
43
src/tinker/_resource.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import anyio
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._client import Tinker, AsyncTinker
|
||||
|
||||
|
||||
class SyncAPIResource:
|
||||
_client: Tinker
|
||||
|
||||
def __init__(self, client: Tinker) -> None:
|
||||
self._client = client
|
||||
self._get = client.get
|
||||
self._post = client.post
|
||||
self._patch = client.patch
|
||||
self._put = client.put
|
||||
self._delete = client.delete
|
||||
self._get_api_list = client.get_api_list
|
||||
|
||||
def _sleep(self, seconds: float) -> None:
|
||||
time.sleep(seconds)
|
||||
|
||||
|
||||
class AsyncAPIResource:
|
||||
_client: AsyncTinker
|
||||
|
||||
def __init__(self, client: AsyncTinker) -> None:
|
||||
self._client = client
|
||||
self._get = client.get
|
||||
self._post = client.post
|
||||
self._patch = client.patch
|
||||
self._put = client.put
|
||||
self._delete = client.delete
|
||||
self._get_api_list = client.get_api_list
|
||||
|
||||
async def _sleep(self, seconds: float) -> None:
|
||||
await anyio.sleep(seconds)
|
||||
830
src/tinker/_response.py
Normal file
830
src/tinker/_response.py
Normal file
|
|
@ -0,0 +1,830 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import inspect
|
||||
import logging
|
||||
import datetime
|
||||
import functools
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Union,
|
||||
Generic,
|
||||
TypeVar,
|
||||
Callable,
|
||||
Iterator,
|
||||
AsyncIterator,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
from typing_extensions import Awaitable, ParamSpec, override, get_origin
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
import pydantic
|
||||
|
||||
from ._types import NoneType
|
||||
from ._utils import is_given, extract_type_arg, is_annotated_type, is_type_alias_type, extract_type_var_from_base
|
||||
from ._models import BaseModel, is_basemodel
|
||||
from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER
|
||||
from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
|
||||
from ._exceptions import TinkerError, APIResponseValidationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._models import FinalRequestOptions
|
||||
from ._base_client import BaseClient
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
_T = TypeVar("_T")
|
||||
_APIResponseT = TypeVar("_APIResponseT", bound="APIResponse[Any]")
|
||||
_AsyncAPIResponseT = TypeVar("_AsyncAPIResponseT", bound="AsyncAPIResponse[Any]")
|
||||
|
||||
log: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseAPIResponse(Generic[R]):
|
||||
_cast_to: type[R]
|
||||
_client: BaseClient[Any, Any]
|
||||
_parsed_by_type: dict[type[Any], Any]
|
||||
_is_sse_stream: bool
|
||||
_stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None
|
||||
_options: FinalRequestOptions
|
||||
|
||||
http_response: httpx.Response
|
||||
|
||||
retries_taken: int
|
||||
"""The number of retries made. If no retries happened this will be `0`"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
raw: httpx.Response,
|
||||
cast_to: type[R],
|
||||
client: BaseClient[Any, Any],
|
||||
stream: bool,
|
||||
stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
|
||||
options: FinalRequestOptions,
|
||||
retries_taken: int = 0,
|
||||
) -> None:
|
||||
self._cast_to = cast_to
|
||||
self._client = client
|
||||
self._parsed_by_type = {}
|
||||
self._is_sse_stream = stream
|
||||
self._stream_cls = stream_cls
|
||||
self._options = options
|
||||
self.http_response = raw
|
||||
self.retries_taken = retries_taken
|
||||
|
||||
@property
|
||||
def headers(self) -> httpx.Headers:
|
||||
return self.http_response.headers
|
||||
|
||||
@property
|
||||
def http_request(self) -> httpx.Request:
|
||||
"""Returns the httpx Request instance associated with the current response."""
|
||||
return self.http_response.request
|
||||
|
||||
@property
|
||||
def status_code(self) -> int:
|
||||
return self.http_response.status_code
|
||||
|
||||
@property
|
||||
def url(self) -> httpx.URL:
|
||||
"""Returns the URL for which the request was made."""
|
||||
return self.http_response.url
|
||||
|
||||
@property
|
||||
def method(self) -> str:
|
||||
return self.http_request.method
|
||||
|
||||
@property
|
||||
def http_version(self) -> str:
|
||||
return self.http_response.http_version
|
||||
|
||||
@property
|
||||
def elapsed(self) -> datetime.timedelta:
|
||||
"""The time taken for the complete request/response cycle to complete."""
|
||||
return self.http_response.elapsed
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
"""Whether or not the response body has been closed.
|
||||
|
||||
If this is False then there is response data that has not been read yet.
|
||||
You must either fully consume the response body or call `.close()`
|
||||
before discarding the response to prevent resource leaks.
|
||||
"""
|
||||
return self.http_response.is_closed
|
||||
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<{self.__class__.__name__} [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_to}>"
|
||||
)
|
||||
|
||||
def _parse(self, *, to: type[_T] | None = None) -> R | _T:
|
||||
cast_to = to if to is not None else self._cast_to
|
||||
|
||||
# unwrap `TypeAlias('Name', T)` -> `T`
|
||||
if is_type_alias_type(cast_to):
|
||||
cast_to = cast_to.__value__ # type: ignore[unreachable]
|
||||
|
||||
# unwrap `Annotated[T, ...]` -> `T`
|
||||
if cast_to and is_annotated_type(cast_to):
|
||||
cast_to = extract_type_arg(cast_to, 0)
|
||||
|
||||
origin = get_origin(cast_to) or cast_to
|
||||
|
||||
if self._is_sse_stream:
|
||||
if to:
|
||||
if not is_stream_class_type(to):
|
||||
raise TypeError(f"Expected custom parse type to be a subclass of {Stream} or {AsyncStream}")
|
||||
|
||||
return cast(
|
||||
_T,
|
||||
to(
|
||||
cast_to=extract_stream_chunk_type(
|
||||
to,
|
||||
failure_message="Expected custom stream type to be passed with a type argument, e.g. Stream[ChunkType]",
|
||||
),
|
||||
response=self.http_response,
|
||||
client=cast(Any, self._client),
|
||||
),
|
||||
)
|
||||
|
||||
if self._stream_cls:
|
||||
return cast(
|
||||
R,
|
||||
self._stream_cls(
|
||||
cast_to=extract_stream_chunk_type(self._stream_cls),
|
||||
response=self.http_response,
|
||||
client=cast(Any, self._client),
|
||||
),
|
||||
)
|
||||
|
||||
stream_cls = cast("type[Stream[Any]] | type[AsyncStream[Any]] | None", self._client._default_stream_cls)
|
||||
if stream_cls is None:
|
||||
raise MissingStreamClassError()
|
||||
|
||||
return cast(
|
||||
R,
|
||||
stream_cls(
|
||||
cast_to=cast_to,
|
||||
response=self.http_response,
|
||||
client=cast(Any, self._client),
|
||||
),
|
||||
)
|
||||
|
||||
if cast_to is NoneType:
|
||||
return cast(R, None)
|
||||
|
||||
response = self.http_response
|
||||
if cast_to == str:
|
||||
return cast(R, response.text)
|
||||
|
||||
if cast_to == bytes:
|
||||
return cast(R, response.content)
|
||||
|
||||
if cast_to == int:
|
||||
return cast(R, int(response.text))
|
||||
|
||||
if cast_to == float:
|
||||
return cast(R, float(response.text))
|
||||
|
||||
if cast_to == bool:
|
||||
return cast(R, response.text.lower() == "true")
|
||||
|
||||
if origin == APIResponse:
|
||||
raise RuntimeError("Unexpected state - cast_to is `APIResponse`")
|
||||
|
||||
if inspect.isclass(origin) and issubclass(origin, httpx.Response):
|
||||
# Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response
|
||||
# and pass that class to our request functions. We cannot change the variance to be either
|
||||
# covariant or contravariant as that makes our usage of ResponseT illegal. We could construct
|
||||
# the response class ourselves but that is something that should be supported directly in httpx
|
||||
# as it would be easy to incorrectly construct the Response object due to the multitude of arguments.
|
||||
if cast_to != httpx.Response:
|
||||
raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`")
|
||||
return cast(R, response)
|
||||
|
||||
if (
|
||||
inspect.isclass(
|
||||
origin # pyright: ignore[reportUnknownArgumentType]
|
||||
)
|
||||
and not issubclass(origin, BaseModel)
|
||||
and issubclass(origin, pydantic.BaseModel)
|
||||
):
|
||||
raise TypeError("Pydantic models must subclass our base model type, e.g. `from tinker import BaseModel`")
|
||||
|
||||
if (
|
||||
cast_to is not object
|
||||
and not origin is list
|
||||
and not origin is dict
|
||||
and not origin is Union
|
||||
and not issubclass(origin, BaseModel)
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"Unsupported type, expected {cast_to} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}."
|
||||
)
|
||||
|
||||
# split is required to handle cases where additional information is included
|
||||
# in the response, e.g. application/json; charset=utf-8
|
||||
content_type, *_ = response.headers.get("content-type", "*").split(";")
|
||||
if not content_type.endswith("json"):
|
||||
if is_basemodel(cast_to):
|
||||
try:
|
||||
data = response.json()
|
||||
except Exception as exc:
|
||||
log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc)
|
||||
else:
|
||||
return self._client._process_response_data(
|
||||
data=data,
|
||||
cast_to=cast_to, # type: ignore
|
||||
response=response,
|
||||
)
|
||||
|
||||
if self._client._strict_response_validation:
|
||||
raise APIResponseValidationError(
|
||||
response=response,
|
||||
message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.",
|
||||
body=response.text,
|
||||
)
|
||||
|
||||
# If the API responds with content that isn't JSON then we just return
|
||||
# the (decoded) text without performing any parsing so that you can still
|
||||
# handle the response however you need to.
|
||||
return response.text # type: ignore
|
||||
|
||||
data = response.json()
|
||||
|
||||
return self._client._process_response_data(
|
||||
data=data,
|
||||
cast_to=cast_to, # type: ignore
|
||||
response=response,
|
||||
)
|
||||
|
||||
|
||||
class APIResponse(BaseAPIResponse[R]):
|
||||
@overload
|
||||
def parse(self, *, to: type[_T]) -> _T: ...
|
||||
|
||||
@overload
|
||||
def parse(self) -> R: ...
|
||||
|
||||
def parse(self, *, to: type[_T] | None = None) -> R | _T:
|
||||
"""Returns the rich python representation of this response's data.
|
||||
|
||||
For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.
|
||||
|
||||
You can customise the type that the response is parsed into through
|
||||
the `to` argument, e.g.
|
||||
|
||||
```py
|
||||
from tinker import BaseModel
|
||||
|
||||
|
||||
class MyModel(BaseModel):
|
||||
foo: str
|
||||
|
||||
|
||||
obj = response.parse(to=MyModel)
|
||||
print(obj.foo)
|
||||
```
|
||||
|
||||
We support parsing:
|
||||
- `BaseModel`
|
||||
- `dict`
|
||||
- `list`
|
||||
- `Union`
|
||||
- `str`
|
||||
- `int`
|
||||
- `float`
|
||||
- `httpx.Response`
|
||||
"""
|
||||
cache_key = to if to is not None else self._cast_to
|
||||
cached = self._parsed_by_type.get(cache_key)
|
||||
if cached is not None:
|
||||
return cached # type: ignore[no-any-return]
|
||||
|
||||
if not self._is_sse_stream:
|
||||
self.read()
|
||||
|
||||
parsed = self._parse(to=to)
|
||||
if is_given(self._options.post_parser):
|
||||
parsed = self._options.post_parser(parsed)
|
||||
|
||||
self._parsed_by_type[cache_key] = parsed
|
||||
return parsed
|
||||
|
||||
def read(self) -> bytes:
|
||||
"""Read and return the binary response content."""
|
||||
try:
|
||||
return self.http_response.read()
|
||||
except httpx.StreamConsumed as exc:
|
||||
# The default error raised by httpx isn't very
|
||||
# helpful in our case so we re-raise it with
|
||||
# a different error message.
|
||||
raise StreamAlreadyConsumed() from exc
|
||||
|
||||
def text(self) -> str:
|
||||
"""Read and decode the response content into a string."""
|
||||
self.read()
|
||||
return self.http_response.text
|
||||
|
||||
def json(self) -> object:
|
||||
"""Read and decode the JSON response content."""
|
||||
self.read()
|
||||
return self.http_response.json()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the response and release the connection.
|
||||
|
||||
Automatically called if the response body is read to completion.
|
||||
"""
|
||||
self.http_response.close()
|
||||
|
||||
def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]:
|
||||
"""
|
||||
A byte-iterator over the decoded response content.
|
||||
|
||||
This automatically handles gzip, deflate and brotli encoded responses.
|
||||
"""
|
||||
for chunk in self.http_response.iter_bytes(chunk_size):
|
||||
yield chunk
|
||||
|
||||
def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:
|
||||
"""A str-iterator over the decoded response content
|
||||
that handles both gzip, deflate, etc but also detects the content's
|
||||
string encoding.
|
||||
"""
|
||||
for chunk in self.http_response.iter_text(chunk_size):
|
||||
yield chunk
|
||||
|
||||
def iter_lines(self) -> Iterator[str]:
|
||||
"""Like `iter_text()` but will only yield chunks for each line"""
|
||||
for chunk in self.http_response.iter_lines():
|
||||
yield chunk
|
||||
|
||||
|
||||
class AsyncAPIResponse(BaseAPIResponse[R]):
|
||||
@overload
|
||||
async def parse(self, *, to: type[_T]) -> _T: ...
|
||||
|
||||
@overload
|
||||
async def parse(self) -> R: ...
|
||||
|
||||
async def parse(self, *, to: type[_T] | None = None) -> R | _T:
|
||||
"""Returns the rich python representation of this response's data.
|
||||
|
||||
For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.
|
||||
|
||||
You can customise the type that the response is parsed into through
|
||||
the `to` argument, e.g.
|
||||
|
||||
```py
|
||||
from tinker import BaseModel
|
||||
|
||||
|
||||
class MyModel(BaseModel):
|
||||
foo: str
|
||||
|
||||
|
||||
obj = response.parse(to=MyModel)
|
||||
print(obj.foo)
|
||||
```
|
||||
|
||||
We support parsing:
|
||||
- `BaseModel`
|
||||
- `dict`
|
||||
- `list`
|
||||
- `Union`
|
||||
- `str`
|
||||
- `httpx.Response`
|
||||
"""
|
||||
cache_key = to if to is not None else self._cast_to
|
||||
cached = self._parsed_by_type.get(cache_key)
|
||||
if cached is not None:
|
||||
return cached # type: ignore[no-any-return]
|
||||
|
||||
if not self._is_sse_stream:
|
||||
await self.read()
|
||||
|
||||
parsed = self._parse(to=to)
|
||||
if is_given(self._options.post_parser):
|
||||
parsed = self._options.post_parser(parsed)
|
||||
|
||||
self._parsed_by_type[cache_key] = parsed
|
||||
return parsed
|
||||
|
||||
async def read(self) -> bytes:
|
||||
"""Read and return the binary response content."""
|
||||
try:
|
||||
return await self.http_response.aread()
|
||||
except httpx.StreamConsumed as exc:
|
||||
# the default error raised by httpx isn't very
|
||||
# helpful in our case so we re-raise it with
|
||||
# a different error message
|
||||
raise StreamAlreadyConsumed() from exc
|
||||
|
||||
async def text(self) -> str:
|
||||
"""Read and decode the response content into a string."""
|
||||
await self.read()
|
||||
return self.http_response.text
|
||||
|
||||
async def json(self) -> object:
|
||||
"""Read and decode the JSON response content."""
|
||||
await self.read()
|
||||
return self.http_response.json()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the response and release the connection.
|
||||
|
||||
Automatically called if the response body is read to completion.
|
||||
"""
|
||||
await self.http_response.aclose()
|
||||
|
||||
async def iter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
|
||||
"""
|
||||
A byte-iterator over the decoded response content.
|
||||
|
||||
This automatically handles gzip, deflate and brotli encoded responses.
|
||||
"""
|
||||
async for chunk in self.http_response.aiter_bytes(chunk_size):
|
||||
yield chunk
|
||||
|
||||
async def iter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]:
|
||||
"""A str-iterator over the decoded response content
|
||||
that handles both gzip, deflate, etc but also detects the content's
|
||||
string encoding.
|
||||
"""
|
||||
async for chunk in self.http_response.aiter_text(chunk_size):
|
||||
yield chunk
|
||||
|
||||
async def iter_lines(self) -> AsyncIterator[str]:
|
||||
"""Like `iter_text()` but will only yield chunks for each line"""
|
||||
async for chunk in self.http_response.aiter_lines():
|
||||
yield chunk
|
||||
|
||||
|
||||
class BinaryAPIResponse(APIResponse[bytes]):
|
||||
"""Subclass of APIResponse providing helpers for dealing with binary data.
|
||||
|
||||
Note: If you want to stream the response data instead of eagerly reading it
|
||||
all at once then you should use `.with_streaming_response` when making
|
||||
the API request, e.g. `.with_streaming_response.get_binary_response()`
|
||||
"""
|
||||
|
||||
def write_to_file(
|
||||
self,
|
||||
file: str | os.PathLike[str],
|
||||
) -> None:
|
||||
"""Write the output to the given file.
|
||||
|
||||
Accepts a filename or any path-like object, e.g. pathlib.Path
|
||||
|
||||
Note: if you want to stream the data to the file instead of writing
|
||||
all at once then you should use `.with_streaming_response` when making
|
||||
the API request, e.g. `.with_streaming_response.get_binary_response()`
|
||||
"""
|
||||
with open(file, mode="wb") as f:
|
||||
for data in self.iter_bytes():
|
||||
f.write(data)
|
||||
|
||||
|
||||
class AsyncBinaryAPIResponse(AsyncAPIResponse[bytes]):
|
||||
"""Subclass of APIResponse providing helpers for dealing with binary data.
|
||||
|
||||
Note: If you want to stream the response data instead of eagerly reading it
|
||||
all at once then you should use `.with_streaming_response` when making
|
||||
the API request, e.g. `.with_streaming_response.get_binary_response()`
|
||||
"""
|
||||
|
||||
async def write_to_file(
|
||||
self,
|
||||
file: str | os.PathLike[str],
|
||||
) -> None:
|
||||
"""Write the output to the given file.
|
||||
|
||||
Accepts a filename or any path-like object, e.g. pathlib.Path
|
||||
|
||||
Note: if you want to stream the data to the file instead of writing
|
||||
all at once then you should use `.with_streaming_response` when making
|
||||
the API request, e.g. `.with_streaming_response.get_binary_response()`
|
||||
"""
|
||||
path = anyio.Path(file)
|
||||
async with await path.open(mode="wb") as f:
|
||||
async for data in self.iter_bytes():
|
||||
await f.write(data)
|
||||
|
||||
|
||||
class StreamedBinaryAPIResponse(APIResponse[bytes]):
|
||||
def stream_to_file(
|
||||
self,
|
||||
file: str | os.PathLike[str],
|
||||
*,
|
||||
chunk_size: int | None = None,
|
||||
) -> None:
|
||||
"""Streams the output to the given file.
|
||||
|
||||
Accepts a filename or any path-like object, e.g. pathlib.Path
|
||||
"""
|
||||
with open(file, mode="wb") as f:
|
||||
for data in self.iter_bytes(chunk_size):
|
||||
f.write(data)
|
||||
|
||||
|
||||
class AsyncStreamedBinaryAPIResponse(AsyncAPIResponse[bytes]):
|
||||
async def stream_to_file(
|
||||
self,
|
||||
file: str | os.PathLike[str],
|
||||
*,
|
||||
chunk_size: int | None = None,
|
||||
) -> None:
|
||||
"""Streams the output to the given file.
|
||||
|
||||
Accepts a filename or any path-like object, e.g. pathlib.Path
|
||||
"""
|
||||
path = anyio.Path(file)
|
||||
async with await path.open(mode="wb") as f:
|
||||
async for data in self.iter_bytes(chunk_size):
|
||||
await f.write(data)
|
||||
|
||||
|
||||
class MissingStreamClassError(TypeError):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
"The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `tinker._streaming` for reference",
|
||||
)
|
||||
|
||||
|
||||
class StreamAlreadyConsumed(TinkerError):
|
||||
"""
|
||||
Attempted to read or stream content, but the content has already
|
||||
been streamed.
|
||||
|
||||
This can happen if you use a method like `.iter_lines()` and then attempt
|
||||
to read th entire response body afterwards, e.g.
|
||||
|
||||
```py
|
||||
response = await client.post(...)
|
||||
async for line in response.iter_lines():
|
||||
... # do something with `line`
|
||||
|
||||
content = await response.read()
|
||||
# ^ error
|
||||
```
|
||||
|
||||
If you want this behaviour you'll need to either manually accumulate the response
|
||||
content or call `await response.read()` before iterating over the stream.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
message = (
|
||||
"Attempted to read or stream some content, but the content has "
|
||||
"already been streamed. "
|
||||
"This could be due to attempting to stream the response "
|
||||
"content more than once."
|
||||
"\n\n"
|
||||
"You can fix this by manually accumulating the response content while streaming "
|
||||
"or by calling `.read()` before starting to stream."
|
||||
)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ResponseContextManager(Generic[_APIResponseT]):
|
||||
"""Context manager for ensuring that a request is not made
|
||||
until it is entered and that the response will always be closed
|
||||
when the context manager exits
|
||||
"""
|
||||
|
||||
def __init__(self, request_func: Callable[[], _APIResponseT]) -> None:
|
||||
self._request_func = request_func
|
||||
self.__response: _APIResponseT | None = None
|
||||
|
||||
def __enter__(self) -> _APIResponseT:
|
||||
self.__response = self._request_func()
|
||||
return self.__response
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
if self.__response is not None:
|
||||
self.__response.close()
|
||||
|
||||
|
||||
class AsyncResponseContextManager(Generic[_AsyncAPIResponseT]):
|
||||
"""Context manager for ensuring that a request is not made
|
||||
until it is entered and that the response will always be closed
|
||||
when the context manager exits
|
||||
"""
|
||||
|
||||
def __init__(self, api_request: Awaitable[_AsyncAPIResponseT]) -> None:
|
||||
self._api_request = api_request
|
||||
self.__response: _AsyncAPIResponseT | None = None
|
||||
|
||||
async def __aenter__(self) -> _AsyncAPIResponseT:
|
||||
self.__response = await self._api_request
|
||||
return self.__response
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
if self.__response is not None:
|
||||
await self.__response.close()
|
||||
|
||||
|
||||
def to_streamed_response_wrapper(func: Callable[P, R]) -> Callable[P, ResponseContextManager[APIResponse[R]]]:
|
||||
"""Higher order function that takes one of our bound API methods and wraps it
|
||||
to support streaming and returning the raw `APIResponse` object directly.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[APIResponse[R]]:
|
||||
extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
|
||||
extra_headers[RAW_RESPONSE_HEADER] = "stream"
|
||||
|
||||
kwargs["extra_headers"] = extra_headers
|
||||
|
||||
make_request = functools.partial(func, *args, **kwargs)
|
||||
|
||||
return ResponseContextManager(cast(Callable[[], APIResponse[R]], make_request))
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def async_to_streamed_response_wrapper(
|
||||
func: Callable[P, Awaitable[R]],
|
||||
) -> Callable[P, AsyncResponseContextManager[AsyncAPIResponse[R]]]:
|
||||
"""Higher order function that takes one of our bound API methods and wraps it
|
||||
to support streaming and returning the raw `APIResponse` object directly.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[AsyncAPIResponse[R]]:
|
||||
extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
|
||||
extra_headers[RAW_RESPONSE_HEADER] = "stream"
|
||||
|
||||
kwargs["extra_headers"] = extra_headers
|
||||
|
||||
make_request = func(*args, **kwargs)
|
||||
|
||||
return AsyncResponseContextManager(cast(Awaitable[AsyncAPIResponse[R]], make_request))
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def to_custom_streamed_response_wrapper(
|
||||
func: Callable[P, object],
|
||||
response_cls: type[_APIResponseT],
|
||||
) -> Callable[P, ResponseContextManager[_APIResponseT]]:
|
||||
"""Higher order function that takes one of our bound API methods and an `APIResponse` class
|
||||
and wraps the method to support streaming and returning the given response class directly.
|
||||
|
||||
Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[_APIResponseT]:
|
||||
extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
|
||||
extra_headers[RAW_RESPONSE_HEADER] = "stream"
|
||||
extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls
|
||||
|
||||
kwargs["extra_headers"] = extra_headers
|
||||
|
||||
make_request = functools.partial(func, *args, **kwargs)
|
||||
|
||||
return ResponseContextManager(cast(Callable[[], _APIResponseT], make_request))
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def async_to_custom_streamed_response_wrapper(
|
||||
func: Callable[P, Awaitable[object]],
|
||||
response_cls: type[_AsyncAPIResponseT],
|
||||
) -> Callable[P, AsyncResponseContextManager[_AsyncAPIResponseT]]:
|
||||
"""Higher order function that takes one of our bound API methods and an `APIResponse` class
|
||||
and wraps the method to support streaming and returning the given response class directly.
|
||||
|
||||
Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[_AsyncAPIResponseT]:
|
||||
extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
|
||||
extra_headers[RAW_RESPONSE_HEADER] = "stream"
|
||||
extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls
|
||||
|
||||
kwargs["extra_headers"] = extra_headers
|
||||
|
||||
make_request = func(*args, **kwargs)
|
||||
|
||||
return AsyncResponseContextManager(cast(Awaitable[_AsyncAPIResponseT], make_request))
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]]:
|
||||
"""Higher order function that takes one of our bound API methods and wraps it
|
||||
to support returning the raw `APIResponse` object directly.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapped(*args: P.args, **kwargs: P.kwargs) -> APIResponse[R]:
|
||||
extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
|
||||
extra_headers[RAW_RESPONSE_HEADER] = "raw"
|
||||
|
||||
kwargs["extra_headers"] = extra_headers
|
||||
|
||||
return cast(APIResponse[R], func(*args, **kwargs))
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[AsyncAPIResponse[R]]]:
|
||||
"""Higher order function that takes one of our bound API methods and wraps it
|
||||
to support returning the raw `APIResponse` object directly.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncAPIResponse[R]:
|
||||
extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
|
||||
extra_headers[RAW_RESPONSE_HEADER] = "raw"
|
||||
|
||||
kwargs["extra_headers"] = extra_headers
|
||||
|
||||
return cast(AsyncAPIResponse[R], await func(*args, **kwargs))
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def to_custom_raw_response_wrapper(
|
||||
func: Callable[P, object],
|
||||
response_cls: type[_APIResponseT],
|
||||
) -> Callable[P, _APIResponseT]:
|
||||
"""Higher order function that takes one of our bound API methods and an `APIResponse` class
|
||||
and wraps the method to support returning the given response class directly.
|
||||
|
||||
Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapped(*args: P.args, **kwargs: P.kwargs) -> _APIResponseT:
|
||||
extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
|
||||
extra_headers[RAW_RESPONSE_HEADER] = "raw"
|
||||
extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls
|
||||
|
||||
kwargs["extra_headers"] = extra_headers
|
||||
|
||||
return cast(_APIResponseT, func(*args, **kwargs))
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def async_to_custom_raw_response_wrapper(
|
||||
func: Callable[P, Awaitable[object]],
|
||||
response_cls: type[_AsyncAPIResponseT],
|
||||
) -> Callable[P, Awaitable[_AsyncAPIResponseT]]:
|
||||
"""Higher order function that takes one of our bound API methods and an `APIResponse` class
|
||||
and wraps the method to support returning the given response class directly.
|
||||
|
||||
Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapped(*args: P.args, **kwargs: P.kwargs) -> Awaitable[_AsyncAPIResponseT]:
|
||||
extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
|
||||
extra_headers[RAW_RESPONSE_HEADER] = "raw"
|
||||
extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls
|
||||
|
||||
kwargs["extra_headers"] = extra_headers
|
||||
|
||||
return cast(Awaitable[_AsyncAPIResponseT], func(*args, **kwargs))
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def extract_response_type(typ: type[BaseAPIResponse[Any]]) -> type:
|
||||
"""Given a type like `APIResponse[T]`, returns the generic type variable `T`.
|
||||
|
||||
This also handles the case where a concrete subclass is given, e.g.
|
||||
```py
|
||||
class MyResponse(APIResponse[bytes]):
|
||||
...
|
||||
|
||||
extract_response_type(MyResponse) -> bytes
|
||||
```
|
||||
"""
|
||||
return extract_type_var_from_base(
|
||||
typ,
|
||||
generic_bases=cast("tuple[type, ...]", (BaseAPIResponse, APIResponse, AsyncAPIResponse)),
|
||||
index=0,
|
||||
)
|
||||
333
src/tinker/_streaming.py
Normal file
333
src/tinker/_streaming.py
Normal file
|
|
@ -0,0 +1,333 @@
|
|||
# Note: initially copied from https://github.com/florimondmanca/httpx-sse/blob/master/src/httpx_sse/_decoders.py
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import inspect
|
||||
from types import TracebackType
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
|
||||
from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable
|
||||
|
||||
import httpx
|
||||
|
||||
from ._utils import extract_type_var_from_base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._client import Tinker, AsyncTinker
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class Stream(Generic[_T]):
|
||||
"""Provides the core interface to iterate over a synchronous stream response."""
|
||||
|
||||
response: httpx.Response
|
||||
|
||||
_decoder: SSEBytesDecoder
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
cast_to: type[_T],
|
||||
response: httpx.Response,
|
||||
client: Tinker,
|
||||
) -> None:
|
||||
self.response = response
|
||||
self._cast_to = cast_to
|
||||
self._client = client
|
||||
self._decoder = client._make_sse_decoder()
|
||||
self._iterator = self.__stream__()
|
||||
|
||||
def __next__(self) -> _T:
|
||||
return self._iterator.__next__()
|
||||
|
||||
def __iter__(self) -> Iterator[_T]:
|
||||
for item in self._iterator:
|
||||
yield item
|
||||
|
||||
def _iter_events(self) -> Iterator[ServerSentEvent]:
|
||||
yield from self._decoder.iter_bytes(self.response.iter_bytes())
|
||||
|
||||
def __stream__(self) -> Iterator[_T]:
|
||||
cast_to = cast(Any, self._cast_to)
|
||||
response = self.response
|
||||
process_data = self._client._process_response_data
|
||||
iterator = self._iter_events()
|
||||
|
||||
for sse in iterator:
|
||||
yield process_data(data=sse.json(), cast_to=cast_to, response=response)
|
||||
|
||||
# Ensure the entire stream is consumed
|
||||
for _sse in iterator:
|
||||
...
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Close the response and release the connection.
|
||||
|
||||
Automatically called if the response body is read to completion.
|
||||
"""
|
||||
self.response.close()
|
||||
|
||||
|
||||
class AsyncStream(Generic[_T]):
|
||||
"""Provides the core interface to iterate over an asynchronous stream response."""
|
||||
|
||||
response: httpx.Response
|
||||
|
||||
_decoder: SSEDecoder | SSEBytesDecoder
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
cast_to: type[_T],
|
||||
response: httpx.Response,
|
||||
client: AsyncTinker,
|
||||
) -> None:
|
||||
self.response = response
|
||||
self._cast_to = cast_to
|
||||
self._client = client
|
||||
self._decoder = client._make_sse_decoder()
|
||||
self._iterator = self.__stream__()
|
||||
|
||||
async def __anext__(self) -> _T:
|
||||
return await self._iterator.__anext__()
|
||||
|
||||
async def __aiter__(self) -> AsyncIterator[_T]:
|
||||
async for item in self._iterator:
|
||||
yield item
|
||||
|
||||
async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
|
||||
async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
|
||||
yield sse
|
||||
|
||||
async def __stream__(self) -> AsyncIterator[_T]:
|
||||
cast_to = cast(Any, self._cast_to)
|
||||
response = self.response
|
||||
process_data = self._client._process_response_data
|
||||
iterator = self._iter_events()
|
||||
|
||||
async for sse in iterator:
|
||||
yield process_data(data=sse.json(), cast_to=cast_to, response=response)
|
||||
|
||||
# Ensure the entire stream is consumed
|
||||
async for _sse in iterator:
|
||||
...
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
await self.close()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
Close the response and release the connection.
|
||||
|
||||
Automatically called if the response body is read to completion.
|
||||
"""
|
||||
await self.response.aclose()
|
||||
|
||||
|
||||
class ServerSentEvent:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
event: str | None = None,
|
||||
data: str | None = None,
|
||||
id: str | None = None,
|
||||
retry: int | None = None,
|
||||
) -> None:
|
||||
if data is None:
|
||||
data = ""
|
||||
|
||||
self._id = id
|
||||
self._data = data
|
||||
self._event = event or None
|
||||
self._retry = retry
|
||||
|
||||
@property
|
||||
def event(self) -> str | None:
|
||||
return self._event
|
||||
|
||||
@property
|
||||
def id(self) -> str | None:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def retry(self) -> int | None:
|
||||
return self._retry
|
||||
|
||||
@property
|
||||
def data(self) -> str:
|
||||
return self._data
|
||||
|
||||
def json(self) -> Any:
|
||||
return json.loads(self.data)
|
||||
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})"
|
||||
|
||||
|
||||
class SSEDecoder:
|
||||
_data: list[str]
|
||||
_event: str | None
|
||||
_retry: int | None
|
||||
_last_event_id: str | None
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._event = None
|
||||
self._data = []
|
||||
self._last_event_id = None
|
||||
self._retry = None
|
||||
|
||||
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
|
||||
"""Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
|
||||
for chunk in self._iter_chunks(iterator):
|
||||
# Split before decoding so splitlines() only uses \r and \n
|
||||
for raw_line in chunk.splitlines():
|
||||
line = raw_line.decode("utf-8")
|
||||
sse = self.decode(line)
|
||||
if sse:
|
||||
yield sse
|
||||
|
||||
def _iter_chunks(self, iterator: Iterator[bytes]) -> Iterator[bytes]:
|
||||
"""Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
|
||||
data = b""
|
||||
for chunk in iterator:
|
||||
for line in chunk.splitlines(keepends=True):
|
||||
data += line
|
||||
if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
|
||||
yield data
|
||||
data = b""
|
||||
if data:
|
||||
yield data
|
||||
|
||||
async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
|
||||
"""Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
|
||||
async for chunk in self._aiter_chunks(iterator):
|
||||
# Split before decoding so splitlines() only uses \r and \n
|
||||
for raw_line in chunk.splitlines():
|
||||
line = raw_line.decode("utf-8")
|
||||
sse = self.decode(line)
|
||||
if sse:
|
||||
yield sse
|
||||
|
||||
async def _aiter_chunks(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
|
||||
"""Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
|
||||
data = b""
|
||||
async for chunk in iterator:
|
||||
for line in chunk.splitlines(keepends=True):
|
||||
data += line
|
||||
if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
|
||||
yield data
|
||||
data = b""
|
||||
if data:
|
||||
yield data
|
||||
|
||||
def decode(self, line: str) -> ServerSentEvent | None:
|
||||
# See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501
|
||||
|
||||
if not line:
|
||||
if not self._event and not self._data and not self._last_event_id and self._retry is None:
|
||||
return None
|
||||
|
||||
sse = ServerSentEvent(
|
||||
event=self._event,
|
||||
data="\n".join(self._data),
|
||||
id=self._last_event_id,
|
||||
retry=self._retry,
|
||||
)
|
||||
|
||||
# NOTE: as per the SSE spec, do not reset last_event_id.
|
||||
self._event = None
|
||||
self._data = []
|
||||
self._retry = None
|
||||
|
||||
return sse
|
||||
|
||||
if line.startswith(":"):
|
||||
return None
|
||||
|
||||
fieldname, _, value = line.partition(":")
|
||||
|
||||
if value.startswith(" "):
|
||||
value = value[1:]
|
||||
|
||||
if fieldname == "event":
|
||||
self._event = value
|
||||
elif fieldname == "data":
|
||||
self._data.append(value)
|
||||
elif fieldname == "id":
|
||||
if "\0" in value:
|
||||
pass
|
||||
else:
|
||||
self._last_event_id = value
|
||||
elif fieldname == "retry":
|
||||
try:
|
||||
self._retry = int(value)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
else:
|
||||
pass # Field is ignored.
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SSEBytesDecoder(Protocol):
|
||||
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
|
||||
"""Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
|
||||
...
|
||||
|
||||
def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
|
||||
"""Given an async iterator that yields raw binary data, iterate over it & yield every event encountered"""
|
||||
...
|
||||
|
||||
|
||||
def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]:
|
||||
"""TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`"""
|
||||
origin = get_origin(typ) or typ
|
||||
return inspect.isclass(origin) and issubclass(origin, (Stream, AsyncStream))
|
||||
|
||||
|
||||
def extract_stream_chunk_type(
|
||||
stream_cls: type,
|
||||
*,
|
||||
failure_message: str | None = None,
|
||||
) -> type:
|
||||
"""Given a type like `Stream[T]`, returns the generic type variable `T`.
|
||||
|
||||
This also handles the case where a concrete subclass is given, e.g.
|
||||
```py
|
||||
class MyStream(Stream[bytes]):
|
||||
...
|
||||
|
||||
extract_stream_chunk_type(MyStream) -> bytes
|
||||
```
|
||||
"""
|
||||
from ._base_client import Stream, AsyncStream
|
||||
|
||||
return extract_type_var_from_base(
|
||||
stream_cls,
|
||||
index=0,
|
||||
generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)),
|
||||
failure_message=failure_message,
|
||||
)
|
||||
219
src/tinker/_types.py
Normal file
219
src/tinker/_types.py
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from os import PathLike
|
||||
from typing import (
|
||||
IO,
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Type,
|
||||
Tuple,
|
||||
Union,
|
||||
Mapping,
|
||||
TypeVar,
|
||||
Callable,
|
||||
Optional,
|
||||
Sequence,
|
||||
)
|
||||
from typing_extensions import Set, Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable
|
||||
|
||||
import httpx
|
||||
import pydantic
|
||||
from httpx import URL, Proxy, Timeout, Response, BaseTransport, AsyncBaseTransport
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._models import BaseModel
|
||||
from ._response import APIResponse, AsyncAPIResponse
|
||||
|
||||
Transport = BaseTransport
|
||||
AsyncTransport = AsyncBaseTransport
|
||||
Query = Mapping[str, object]
|
||||
Body = object
|
||||
AnyMapping = Mapping[str, object]
|
||||
ModelT = TypeVar("ModelT", bound=pydantic.BaseModel)
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
# Approximates httpx internal ProxiesTypes and RequestFiles types
|
||||
# while adding support for `PathLike` instances
|
||||
ProxiesDict = Dict["str | URL", Union[None, str, URL, Proxy]]
|
||||
ProxiesTypes = Union[str, Proxy, ProxiesDict]
|
||||
if TYPE_CHECKING:
|
||||
Base64FileInput = Union[IO[bytes], PathLike[str]]
|
||||
FileContent = Union[IO[bytes], bytes, PathLike[str]]
|
||||
else:
|
||||
Base64FileInput = Union[IO[bytes], PathLike]
|
||||
FileContent = Union[IO[bytes], bytes, PathLike] # PathLike is not subscriptable in Python 3.8.
|
||||
FileTypes = Union[
|
||||
# file (or bytes)
|
||||
FileContent,
|
||||
# (filename, file (or bytes))
|
||||
Tuple[Optional[str], FileContent],
|
||||
# (filename, file (or bytes), content_type)
|
||||
Tuple[Optional[str], FileContent, Optional[str]],
|
||||
# (filename, file (or bytes), content_type, headers)
|
||||
Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]],
|
||||
]
|
||||
RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]]
|
||||
|
||||
# duplicate of the above but without our custom file support
|
||||
HttpxFileContent = Union[IO[bytes], bytes]
|
||||
HttpxFileTypes = Union[
|
||||
# file (or bytes)
|
||||
HttpxFileContent,
|
||||
# (filename, file (or bytes))
|
||||
Tuple[Optional[str], HttpxFileContent],
|
||||
# (filename, file (or bytes), content_type)
|
||||
Tuple[Optional[str], HttpxFileContent, Optional[str]],
|
||||
# (filename, file (or bytes), content_type, headers)
|
||||
Tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]],
|
||||
]
|
||||
HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]]
|
||||
|
||||
# Workaround to support (cast_to: Type[ResponseT]) -> ResponseT
|
||||
# where ResponseT includes `None`. In order to support directly
|
||||
# passing `None`, overloads would have to be defined for every
|
||||
# method that uses `ResponseT` which would lead to an unacceptable
|
||||
# amount of code duplication and make it unreadable. See _base_client.py
|
||||
# for example usage.
|
||||
#
|
||||
# This unfortunately means that you will either have
|
||||
# to import this type and pass it explicitly:
|
||||
#
|
||||
# from tinker import NoneType
|
||||
# client.get('/foo', cast_to=NoneType)
|
||||
#
|
||||
# or build it yourself:
|
||||
#
|
||||
# client.get('/foo', cast_to=type(None))
|
||||
if TYPE_CHECKING:
|
||||
NoneType: Type[None]
|
||||
else:
|
||||
NoneType = type(None)
|
||||
|
||||
|
||||
class RequestOptions(TypedDict, total=False):
|
||||
headers: Headers
|
||||
max_retries: int
|
||||
timeout: float | Timeout | None
|
||||
params: Query
|
||||
extra_json: AnyMapping
|
||||
idempotency_key: str
|
||||
follow_redirects: bool
|
||||
|
||||
|
||||
# Sentinel class used until PEP 0661 is accepted
|
||||
class NotGiven:
|
||||
"""
|
||||
A sentinel singleton class used to distinguish omitted keyword arguments
|
||||
from those passed in with the value None (which may have different behavior).
|
||||
|
||||
For example:
|
||||
|
||||
```py
|
||||
def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ...
|
||||
|
||||
|
||||
get(timeout=1) # 1s timeout
|
||||
get(timeout=None) # No timeout
|
||||
get() # Default timeout behavior, which may not be statically known at the method definition.
|
||||
```
|
||||
"""
|
||||
|
||||
def __bool__(self) -> Literal[False]:
|
||||
return False
|
||||
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
return "NOT_GIVEN"
|
||||
|
||||
|
||||
NotGivenOr = Union[_T, NotGiven]
|
||||
NOT_GIVEN = NotGiven()
|
||||
|
||||
|
||||
class Omit:
|
||||
"""In certain situations you need to be able to represent a case where a default value has
|
||||
to be explicitly removed and `None` is not an appropriate substitute, for example:
|
||||
|
||||
```py
|
||||
# as the default `Content-Type` header is `application/json` that will be sent
|
||||
client.post("/upload/files", files={"file": b"my raw file content"})
|
||||
|
||||
# you can't explicitly override the header as it has to be dynamically generated
|
||||
# to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983'
|
||||
client.post(..., headers={"Content-Type": "multipart/form-data"})
|
||||
|
||||
# instead you can remove the default `application/json` header by passing Omit
|
||||
client.post(..., headers={"Content-Type": Omit()})
|
||||
```
|
||||
"""
|
||||
|
||||
def __bool__(self) -> Literal[False]:
|
||||
return False
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ModelBuilderProtocol(Protocol):
|
||||
@classmethod
|
||||
def build(
|
||||
cls: type[_T],
|
||||
*,
|
||||
response: Response,
|
||||
data: object,
|
||||
) -> _T: ...
|
||||
|
||||
|
||||
Headers = Mapping[str, Union[str, Omit]]
|
||||
|
||||
|
||||
class HeadersLikeProtocol(Protocol):
|
||||
def get(self, __key: str) -> str | None: ...
|
||||
|
||||
|
||||
HeadersLike = Union[Headers, HeadersLikeProtocol]
|
||||
|
||||
ResponseT = TypeVar(
|
||||
"ResponseT",
|
||||
bound=Union[
|
||||
object,
|
||||
str,
|
||||
None,
|
||||
"BaseModel",
|
||||
List[Any],
|
||||
Dict[str, Any],
|
||||
Response,
|
||||
ModelBuilderProtocol,
|
||||
"APIResponse[Any]",
|
||||
"AsyncAPIResponse[Any]",
|
||||
],
|
||||
)
|
||||
|
||||
StrBytesIntFloat = Union[str, bytes, int, float]
|
||||
|
||||
# Note: copied from Pydantic
|
||||
# https://github.com/pydantic/pydantic/blob/6f31f8f68ef011f84357330186f603ff295312fd/pydantic/main.py#L79
|
||||
IncEx: TypeAlias = Union[Set[int], Set[str], Mapping[int, Union["IncEx", bool]], Mapping[str, Union["IncEx", bool]]]
|
||||
|
||||
PostParser = Callable[[Any], Any]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class InheritsGeneric(Protocol):
|
||||
"""Represents a type that has inherited from `Generic`
|
||||
|
||||
The `__orig_bases__` property can be used to determine the resolved
|
||||
type variable for a given base class.
|
||||
"""
|
||||
|
||||
__orig_bases__: tuple[_GenericAlias]
|
||||
|
||||
|
||||
class _GenericAlias(Protocol):
|
||||
__origin__: type[object]
|
||||
|
||||
|
||||
class HttpxSendArgs(TypedDict, total=False):
|
||||
auth: httpx.Auth
|
||||
follow_redirects: bool
|
||||
57
src/tinker/_utils/__init__.py
Normal file
57
src/tinker/_utils/__init__.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
from ._sync import asyncify as asyncify
|
||||
from ._proxy import LazyProxy as LazyProxy
|
||||
from ._utils import (
|
||||
flatten as flatten,
|
||||
is_dict as is_dict,
|
||||
is_list as is_list,
|
||||
is_given as is_given,
|
||||
is_tuple as is_tuple,
|
||||
json_safe as json_safe,
|
||||
lru_cache as lru_cache,
|
||||
is_mapping as is_mapping,
|
||||
is_tuple_t as is_tuple_t,
|
||||
parse_date as parse_date,
|
||||
is_iterable as is_iterable,
|
||||
is_sequence as is_sequence,
|
||||
coerce_float as coerce_float,
|
||||
is_mapping_t as is_mapping_t,
|
||||
removeprefix as removeprefix,
|
||||
removesuffix as removesuffix,
|
||||
extract_files as extract_files,
|
||||
is_sequence_t as is_sequence_t,
|
||||
required_args as required_args,
|
||||
coerce_boolean as coerce_boolean,
|
||||
coerce_integer as coerce_integer,
|
||||
file_from_path as file_from_path,
|
||||
parse_datetime as parse_datetime,
|
||||
strip_not_given as strip_not_given,
|
||||
deepcopy_minimal as deepcopy_minimal,
|
||||
get_async_library as get_async_library,
|
||||
maybe_coerce_float as maybe_coerce_float,
|
||||
get_required_header as get_required_header,
|
||||
maybe_coerce_boolean as maybe_coerce_boolean,
|
||||
maybe_coerce_integer as maybe_coerce_integer,
|
||||
)
|
||||
from ._typing import (
|
||||
is_list_type as is_list_type,
|
||||
is_union_type as is_union_type,
|
||||
extract_type_arg as extract_type_arg,
|
||||
is_iterable_type as is_iterable_type,
|
||||
is_required_type as is_required_type,
|
||||
is_annotated_type as is_annotated_type,
|
||||
is_type_alias_type as is_type_alias_type,
|
||||
strip_annotated_type as strip_annotated_type,
|
||||
extract_type_var_from_base as extract_type_var_from_base,
|
||||
)
|
||||
from ._streams import consume_sync_iterator as consume_sync_iterator, consume_async_iterator as consume_async_iterator
|
||||
from ._transform import (
|
||||
PropertyInfo as PropertyInfo,
|
||||
transform as transform,
|
||||
async_transform as async_transform,
|
||||
maybe_transform as maybe_transform,
|
||||
async_maybe_transform as async_maybe_transform,
|
||||
)
|
||||
from ._reflection import (
|
||||
function_has_argument as function_has_argument,
|
||||
assert_signatures_in_sync as assert_signatures_in_sync,
|
||||
)
|
||||
25
src/tinker/_utils/_logs.py
Normal file
25
src/tinker/_utils/_logs.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
import os
|
||||
import logging
|
||||
|
||||
logger: logging.Logger = logging.getLogger("tinker")
|
||||
httpx_logger: logging.Logger = logging.getLogger("httpx")
|
||||
|
||||
|
||||
def _basic_config() -> None:
|
||||
# e.g. [2023-10-05 14:12:26 - tinker._base_client:818 - DEBUG] HTTP Request: POST http://127.0.0.1:4010/foo/bar "200 OK"
|
||||
logging.basicConfig(
|
||||
format="[%(asctime)s - %(name)s:%(lineno)d - %(levelname)s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
|
||||
def setup_logging() -> None:
|
||||
env = os.environ.get("TINKER_LOG")
|
||||
if env == "debug":
|
||||
_basic_config()
|
||||
logger.setLevel(logging.DEBUG)
|
||||
httpx_logger.setLevel(logging.DEBUG)
|
||||
elif env == "info":
|
||||
_basic_config()
|
||||
logger.setLevel(logging.INFO)
|
||||
httpx_logger.setLevel(logging.INFO)
|
||||
65
src/tinker/_utils/_proxy.py
Normal file
65
src/tinker/_utils/_proxy.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generic, TypeVar, Iterable, cast
|
||||
from typing_extensions import override
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class LazyProxy(Generic[T], ABC):
|
||||
"""Implements data methods to pretend that an instance is another instance.
|
||||
|
||||
This includes forwarding attribute access and other methods.
|
||||
"""
|
||||
|
||||
# Note: we have to special case proxies that themselves return proxies
|
||||
# to support using a proxy as a catch-all for any random access, e.g. `proxy.foo.bar.baz`
|
||||
|
||||
def __getattr__(self, attr: str) -> object:
|
||||
proxied = self.__get_proxied__()
|
||||
if isinstance(proxied, LazyProxy):
|
||||
return proxied # pyright: ignore
|
||||
return getattr(proxied, attr)
|
||||
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
proxied = self.__get_proxied__()
|
||||
if isinstance(proxied, LazyProxy):
|
||||
return proxied.__class__.__name__
|
||||
return repr(self.__get_proxied__())
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
proxied = self.__get_proxied__()
|
||||
if isinstance(proxied, LazyProxy):
|
||||
return proxied.__class__.__name__
|
||||
return str(proxied)
|
||||
|
||||
@override
|
||||
def __dir__(self) -> Iterable[str]:
|
||||
proxied = self.__get_proxied__()
|
||||
if isinstance(proxied, LazyProxy):
|
||||
return []
|
||||
return proxied.__dir__()
|
||||
|
||||
@property # type: ignore
|
||||
@override
|
||||
def __class__(self) -> type: # pyright: ignore
|
||||
try:
|
||||
proxied = self.__get_proxied__()
|
||||
except Exception:
|
||||
return type(self)
|
||||
if issubclass(type(proxied), LazyProxy):
|
||||
return type(proxied)
|
||||
return proxied.__class__
|
||||
|
||||
def __get_proxied__(self) -> T:
|
||||
return self.__load__()
|
||||
|
||||
def __as_proxied__(self) -> T:
|
||||
"""Helper method that returns the current proxy, typed as the loaded object"""
|
||||
return cast(T, self)
|
||||
|
||||
@abstractmethod
|
||||
def __load__(self) -> T: ...
|
||||
42
src/tinker/_utils/_reflection.py
Normal file
42
src/tinker/_utils/_reflection.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
def function_has_argument(func: Callable[..., Any], arg_name: str) -> bool:
|
||||
"""Returns whether or not the given function has a specific parameter"""
|
||||
sig = inspect.signature(func)
|
||||
return arg_name in sig.parameters
|
||||
|
||||
|
||||
def assert_signatures_in_sync(
|
||||
source_func: Callable[..., Any],
|
||||
check_func: Callable[..., Any],
|
||||
*,
|
||||
exclude_params: set[str] = set(),
|
||||
) -> None:
|
||||
"""Ensure that the signature of the second function matches the first."""
|
||||
|
||||
check_sig = inspect.signature(check_func)
|
||||
source_sig = inspect.signature(source_func)
|
||||
|
||||
errors: list[str] = []
|
||||
|
||||
for name, source_param in source_sig.parameters.items():
|
||||
if name in exclude_params:
|
||||
continue
|
||||
|
||||
custom_param = check_sig.parameters.get(name)
|
||||
if not custom_param:
|
||||
errors.append(f"the `{name}` param is missing")
|
||||
continue
|
||||
|
||||
if custom_param.annotation != source_param.annotation:
|
||||
errors.append(
|
||||
f"types for the `{name}` param are do not match; source={repr(source_param.annotation)} checking={repr(custom_param.annotation)}"
|
||||
)
|
||||
continue
|
||||
|
||||
if errors:
|
||||
raise AssertionError(f"{len(errors)} errors encountered when comparing signatures:\n\n" + "\n\n".join(errors))
|
||||
24
src/tinker/_utils/_resources_proxy.py
Normal file
24
src/tinker/_utils/_resources_proxy.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing_extensions import override
|
||||
|
||||
from ._proxy import LazyProxy
|
||||
|
||||
|
||||
class ResourcesProxy(LazyProxy[Any]):
|
||||
"""A proxy for the `tinker.resources` module.
|
||||
|
||||
This is used so that we can lazily import `tinker.resources` only when
|
||||
needed *and* so that users can just import `tinker` and reference `tinker.resources`
|
||||
"""
|
||||
|
||||
@override
|
||||
def __load__(self) -> Any:
|
||||
import importlib
|
||||
|
||||
mod = importlib.import_module("tinker.resources")
|
||||
return mod
|
||||
|
||||
|
||||
resources = ResourcesProxy().__as_proxied__()
|
||||
12
src/tinker/_utils/_streams.py
Normal file
12
src/tinker/_utils/_streams.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
from typing import Any
|
||||
from typing_extensions import Iterator, AsyncIterator
|
||||
|
||||
|
||||
def consume_sync_iterator(iterator: Iterator[Any]) -> None:
|
||||
for _ in iterator:
|
||||
...
|
||||
|
||||
|
||||
async def consume_async_iterator(iterator: AsyncIterator[Any]) -> None:
|
||||
async for _ in iterator:
|
||||
...
|
||||
86
src/tinker/_utils/_sync.py
Normal file
86
src/tinker/_utils/_sync.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
import functools
|
||||
import contextvars
|
||||
from typing import Any, TypeVar, Callable, Awaitable
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
import anyio
|
||||
import sniffio
|
||||
import anyio.to_thread
|
||||
|
||||
T_Retval = TypeVar("T_Retval")
|
||||
T_ParamSpec = ParamSpec("T_ParamSpec")
|
||||
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
_asyncio_to_thread = asyncio.to_thread
|
||||
else:
|
||||
# backport of https://docs.python.org/3/library/asyncio-task.html#asyncio.to_thread
|
||||
# for Python 3.8 support
|
||||
async def _asyncio_to_thread(
|
||||
func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs
|
||||
) -> Any:
|
||||
"""Asynchronously run function *func* in a separate thread.
|
||||
|
||||
Any *args and **kwargs supplied for this function are directly passed
|
||||
to *func*. Also, the current :class:`contextvars.Context` is propagated,
|
||||
allowing context variables from the main thread to be accessed in the
|
||||
separate thread.
|
||||
|
||||
Returns a coroutine that can be awaited to get the eventual result of *func*.
|
||||
"""
|
||||
loop = asyncio.events.get_running_loop()
|
||||
ctx = contextvars.copy_context()
|
||||
func_call = functools.partial(ctx.run, func, *args, **kwargs)
|
||||
return await loop.run_in_executor(None, func_call)
|
||||
|
||||
|
||||
async def to_thread(
|
||||
func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs
|
||||
) -> T_Retval:
|
||||
if sniffio.current_async_library() == "asyncio":
|
||||
return await _asyncio_to_thread(func, *args, **kwargs)
|
||||
|
||||
return await anyio.to_thread.run_sync(
|
||||
functools.partial(func, *args, **kwargs),
|
||||
)
|
||||
|
||||
|
||||
# inspired by `asyncer`, https://github.com/tiangolo/asyncer
|
||||
def asyncify(function: Callable[T_ParamSpec, T_Retval]) -> Callable[T_ParamSpec, Awaitable[T_Retval]]:
|
||||
"""
|
||||
Take a blocking function and create an async one that receives the same
|
||||
positional and keyword arguments. For python version 3.9 and above, it uses
|
||||
asyncio.to_thread to run the function in a separate thread. For python version
|
||||
3.8, it uses locally defined copy of the asyncio.to_thread function which was
|
||||
introduced in python 3.9.
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
def blocking_func(arg1, arg2, kwarg1=None):
|
||||
# blocking code
|
||||
return result
|
||||
|
||||
|
||||
result = asyncify(blocking_function)(arg1, arg2, kwarg1=value1)
|
||||
```
|
||||
|
||||
## Arguments
|
||||
|
||||
`function`: a blocking regular callable (e.g. a function)
|
||||
|
||||
## Return
|
||||
|
||||
An async function that takes the same positional and keyword arguments as the
|
||||
original one, that when called runs the same original function in a thread worker
|
||||
and returns the result.
|
||||
"""
|
||||
|
||||
async def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval:
|
||||
return await to_thread(function, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
447
src/tinker/_utils/_transform.py
Normal file
447
src/tinker/_utils/_transform.py
Normal file
|
|
@ -0,0 +1,447 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import base64
|
||||
import pathlib
|
||||
from typing import Any, Mapping, TypeVar, cast
|
||||
from datetime import date, datetime
|
||||
from typing_extensions import Literal, get_args, override, get_type_hints as _get_type_hints
|
||||
|
||||
import anyio
|
||||
import pydantic
|
||||
|
||||
from ._utils import (
|
||||
is_list,
|
||||
is_given,
|
||||
lru_cache,
|
||||
is_mapping,
|
||||
is_iterable,
|
||||
)
|
||||
from .._files import is_base64_file_input
|
||||
from ._typing import (
|
||||
is_list_type,
|
||||
is_union_type,
|
||||
extract_type_arg,
|
||||
is_iterable_type,
|
||||
is_required_type,
|
||||
is_annotated_type,
|
||||
strip_annotated_type,
|
||||
)
|
||||
from .._compat import get_origin, model_dump, is_typeddict
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
# TODO: support for drilling globals() and locals()
|
||||
# TODO: ensure works correctly with forward references in all cases
|
||||
|
||||
|
||||
PropertyFormat = Literal["iso8601", "base64", "custom"]
|
||||
|
||||
|
||||
class PropertyInfo:
|
||||
"""Metadata class to be used in Annotated types to provide information about a given type.
|
||||
|
||||
For example:
|
||||
|
||||
class MyParams(TypedDict):
|
||||
account_holder_name: Annotated[str, PropertyInfo(alias='accountHolderName')]
|
||||
|
||||
This means that {'account_holder_name': 'Robert'} will be transformed to {'accountHolderName': 'Robert'} before being sent to the API.
|
||||
"""
|
||||
|
||||
alias: str | None
|
||||
format: PropertyFormat | None
|
||||
format_template: str | None
|
||||
discriminator: str | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
alias: str | None = None,
|
||||
format: PropertyFormat | None = None,
|
||||
format_template: str | None = None,
|
||||
discriminator: str | None = None,
|
||||
) -> None:
|
||||
self.alias = alias
|
||||
self.format = format
|
||||
self.format_template = format_template
|
||||
self.discriminator = discriminator
|
||||
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')"
|
||||
|
||||
|
||||
def maybe_transform(
|
||||
data: object,
|
||||
expected_type: object,
|
||||
) -> Any | None:
|
||||
"""Wrapper over `transform()` that allows `None` to be passed.
|
||||
|
||||
See `transform()` for more details.
|
||||
"""
|
||||
if data is None:
|
||||
return None
|
||||
return transform(data, expected_type)
|
||||
|
||||
|
||||
# Wrapper over _transform_recursive providing fake types
|
||||
def transform(
|
||||
data: _T,
|
||||
expected_type: object,
|
||||
) -> _T:
|
||||
"""Transform dictionaries based off of type information from the given type, for example:
|
||||
|
||||
```py
|
||||
class Params(TypedDict, total=False):
|
||||
card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]
|
||||
|
||||
|
||||
transformed = transform({"card_id": "<my card ID>"}, Params)
|
||||
# {'cardID': '<my card ID>'}
|
||||
```
|
||||
|
||||
Any keys / data that does not have type information given will be included as is.
|
||||
|
||||
It should be noted that the transformations that this function does are not represented in the type system.
|
||||
"""
|
||||
transformed = _transform_recursive(data, annotation=cast(type, expected_type))
|
||||
return cast(_T, transformed)
|
||||
|
||||
|
||||
@lru_cache(maxsize=8096)
|
||||
def _get_annotated_type(type_: type) -> type | None:
|
||||
"""If the given type is an `Annotated` type then it is returned, if not `None` is returned.
|
||||
|
||||
This also unwraps the type when applicable, e.g. `Required[Annotated[T, ...]]`
|
||||
"""
|
||||
if is_required_type(type_):
|
||||
# Unwrap `Required[Annotated[T, ...]]` to `Annotated[T, ...]`
|
||||
type_ = get_args(type_)[0]
|
||||
|
||||
if is_annotated_type(type_):
|
||||
return type_
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _maybe_transform_key(key: str, type_: type) -> str:
|
||||
"""Transform the given `data` based on the annotations provided in `type_`.
|
||||
|
||||
Note: this function only looks at `Annotated` types that contain `PropertyInfo` metadata.
|
||||
"""
|
||||
annotated_type = _get_annotated_type(type_)
|
||||
if annotated_type is None:
|
||||
# no `Annotated` definition for this type, no transformation needed
|
||||
return key
|
||||
|
||||
# ignore the first argument as it is the actual type
|
||||
annotations = get_args(annotated_type)[1:]
|
||||
for annotation in annotations:
|
||||
if isinstance(annotation, PropertyInfo) and annotation.alias is not None:
|
||||
return annotation.alias
|
||||
|
||||
return key
|
||||
|
||||
|
||||
def _no_transform_needed(annotation: type) -> bool:
|
||||
return annotation == float or annotation == int
|
||||
|
||||
|
||||
def _transform_recursive(
|
||||
data: object,
|
||||
*,
|
||||
annotation: type,
|
||||
inner_type: type | None = None,
|
||||
) -> object:
|
||||
"""Transform the given data against the expected type.
|
||||
|
||||
Args:
|
||||
annotation: The direct type annotation given to the particular piece of data.
|
||||
This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc
|
||||
|
||||
inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
|
||||
is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
|
||||
the list can be transformed using the metadata from the container type.
|
||||
|
||||
Defaults to the same value as the `annotation` argument.
|
||||
"""
|
||||
if inner_type is None:
|
||||
inner_type = annotation
|
||||
|
||||
stripped_type = strip_annotated_type(inner_type)
|
||||
origin = get_origin(stripped_type) or stripped_type
|
||||
if is_typeddict(stripped_type) and is_mapping(data):
|
||||
return _transform_typeddict(data, stripped_type)
|
||||
|
||||
if origin == dict and is_mapping(data):
|
||||
items_type = get_args(stripped_type)[1]
|
||||
return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}
|
||||
|
||||
if (
|
||||
# List[T]
|
||||
(is_list_type(stripped_type) and is_list(data))
|
||||
# Iterable[T]
|
||||
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
|
||||
):
|
||||
# dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
|
||||
# intended as an iterable, so we don't transform it.
|
||||
if isinstance(data, dict):
|
||||
return cast(object, data)
|
||||
|
||||
inner_type = extract_type_arg(stripped_type, 0)
|
||||
if _no_transform_needed(inner_type):
|
||||
# for some types there is no need to transform anything, so we can get a small
|
||||
# perf boost from skipping that work.
|
||||
#
|
||||
# but we still need to convert to a list to ensure the data is json-serializable
|
||||
if is_list(data):
|
||||
return data
|
||||
return list(data)
|
||||
|
||||
return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
|
||||
|
||||
if is_union_type(stripped_type):
|
||||
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
|
||||
#
|
||||
# TODO: there may be edge cases where the same normalized field name will transform to two different names
|
||||
# in different subtypes.
|
||||
for subtype in get_args(stripped_type):
|
||||
data = _transform_recursive(data, annotation=annotation, inner_type=subtype)
|
||||
return data
|
||||
|
||||
if isinstance(data, pydantic.BaseModel):
|
||||
return model_dump(data, exclude_unset=True, mode="json")
|
||||
|
||||
annotated_type = _get_annotated_type(annotation)
|
||||
if annotated_type is None:
|
||||
return data
|
||||
|
||||
# ignore the first argument as it is the actual type
|
||||
annotations = get_args(annotated_type)[1:]
|
||||
for annotation in annotations:
|
||||
if isinstance(annotation, PropertyInfo) and annotation.format is not None:
|
||||
return _format_data(data, annotation.format, annotation.format_template)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
|
||||
if isinstance(data, (date, datetime)):
|
||||
if format_ == "iso8601":
|
||||
return data.isoformat()
|
||||
|
||||
if format_ == "custom" and format_template is not None:
|
||||
return data.strftime(format_template)
|
||||
|
||||
if format_ == "base64" and is_base64_file_input(data):
|
||||
binary: str | bytes | None = None
|
||||
|
||||
if isinstance(data, pathlib.Path):
|
||||
binary = data.read_bytes()
|
||||
elif isinstance(data, io.IOBase):
|
||||
binary = data.read()
|
||||
|
||||
if isinstance(binary, str): # type: ignore[unreachable]
|
||||
binary = binary.encode()
|
||||
|
||||
if not isinstance(binary, bytes):
|
||||
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
|
||||
|
||||
return base64.b64encode(binary).decode("ascii")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _transform_typeddict(
|
||||
data: Mapping[str, object],
|
||||
expected_type: type,
|
||||
) -> Mapping[str, object]:
|
||||
result: dict[str, object] = {}
|
||||
annotations = get_type_hints(expected_type, include_extras=True)
|
||||
for key, value in data.items():
|
||||
if not is_given(value):
|
||||
# we don't need to include `NotGiven` values here as they'll
|
||||
# be stripped out before the request is sent anyway
|
||||
continue
|
||||
|
||||
type_ = annotations.get(key)
|
||||
if type_ is None:
|
||||
# we do not have a type annotation for this field, leave it as is
|
||||
result[key] = value
|
||||
else:
|
||||
result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_)
|
||||
return result
|
||||
|
||||
|
||||
async def async_maybe_transform(
|
||||
data: object,
|
||||
expected_type: object,
|
||||
) -> Any | None:
|
||||
"""Wrapper over `async_transform()` that allows `None` to be passed.
|
||||
|
||||
See `async_transform()` for more details.
|
||||
"""
|
||||
if data is None:
|
||||
return None
|
||||
return await async_transform(data, expected_type)
|
||||
|
||||
|
||||
async def async_transform(
|
||||
data: _T,
|
||||
expected_type: object,
|
||||
) -> _T:
|
||||
"""Transform dictionaries based off of type information from the given type, for example:
|
||||
|
||||
```py
|
||||
class Params(TypedDict, total=False):
|
||||
card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]
|
||||
|
||||
|
||||
transformed = transform({"card_id": "<my card ID>"}, Params)
|
||||
# {'cardID': '<my card ID>'}
|
||||
```
|
||||
|
||||
Any keys / data that does not have type information given will be included as is.
|
||||
|
||||
It should be noted that the transformations that this function does are not represented in the type system.
|
||||
"""
|
||||
transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type))
|
||||
return cast(_T, transformed)
|
||||
|
||||
|
||||
async def _async_transform_recursive(
|
||||
data: object,
|
||||
*,
|
||||
annotation: type,
|
||||
inner_type: type | None = None,
|
||||
) -> object:
|
||||
"""Transform the given data against the expected type.
|
||||
|
||||
Args:
|
||||
annotation: The direct type annotation given to the particular piece of data.
|
||||
This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc
|
||||
|
||||
inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
|
||||
is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
|
||||
the list can be transformed using the metadata from the container type.
|
||||
|
||||
Defaults to the same value as the `annotation` argument.
|
||||
"""
|
||||
if inner_type is None:
|
||||
inner_type = annotation
|
||||
|
||||
stripped_type = strip_annotated_type(inner_type)
|
||||
origin = get_origin(stripped_type) or stripped_type
|
||||
if is_typeddict(stripped_type) and is_mapping(data):
|
||||
return await _async_transform_typeddict(data, stripped_type)
|
||||
|
||||
if origin == dict and is_mapping(data):
|
||||
items_type = get_args(stripped_type)[1]
|
||||
return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}
|
||||
|
||||
if (
|
||||
# List[T]
|
||||
(is_list_type(stripped_type) and is_list(data))
|
||||
# Iterable[T]
|
||||
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
|
||||
):
|
||||
# dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
|
||||
# intended as an iterable, so we don't transform it.
|
||||
if isinstance(data, dict):
|
||||
return cast(object, data)
|
||||
|
||||
inner_type = extract_type_arg(stripped_type, 0)
|
||||
if _no_transform_needed(inner_type):
|
||||
# for some types there is no need to transform anything, so we can get a small
|
||||
# perf boost from skipping that work.
|
||||
#
|
||||
# but we still need to convert to a list to ensure the data is json-serializable
|
||||
if is_list(data):
|
||||
return data
|
||||
return list(data)
|
||||
|
||||
return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
|
||||
|
||||
if is_union_type(stripped_type):
|
||||
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
|
||||
#
|
||||
# TODO: there may be edge cases where the same normalized field name will transform to two different names
|
||||
# in different subtypes.
|
||||
for subtype in get_args(stripped_type):
|
||||
data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
|
||||
return data
|
||||
|
||||
if isinstance(data, pydantic.BaseModel):
|
||||
return model_dump(data, exclude_unset=True, mode="json")
|
||||
|
||||
annotated_type = _get_annotated_type(annotation)
|
||||
if annotated_type is None:
|
||||
return data
|
||||
|
||||
# ignore the first argument as it is the actual type
|
||||
annotations = get_args(annotated_type)[1:]
|
||||
for annotation in annotations:
|
||||
if isinstance(annotation, PropertyInfo) and annotation.format is not None:
|
||||
return await _async_format_data(data, annotation.format, annotation.format_template)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
|
||||
if isinstance(data, (date, datetime)):
|
||||
if format_ == "iso8601":
|
||||
return data.isoformat()
|
||||
|
||||
if format_ == "custom" and format_template is not None:
|
||||
return data.strftime(format_template)
|
||||
|
||||
if format_ == "base64" and is_base64_file_input(data):
|
||||
binary: str | bytes | None = None
|
||||
|
||||
if isinstance(data, pathlib.Path):
|
||||
binary = await anyio.Path(data).read_bytes()
|
||||
elif isinstance(data, io.IOBase):
|
||||
binary = data.read()
|
||||
|
||||
if isinstance(binary, str): # type: ignore[unreachable]
|
||||
binary = binary.encode()
|
||||
|
||||
if not isinstance(binary, bytes):
|
||||
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
|
||||
|
||||
return base64.b64encode(binary).decode("ascii")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
async def _async_transform_typeddict(
|
||||
data: Mapping[str, object],
|
||||
expected_type: type,
|
||||
) -> Mapping[str, object]:
|
||||
result: dict[str, object] = {}
|
||||
annotations = get_type_hints(expected_type, include_extras=True)
|
||||
for key, value in data.items():
|
||||
if not is_given(value):
|
||||
# we don't need to include `NotGiven` values here as they'll
|
||||
# be stripped out before the request is sent anyway
|
||||
continue
|
||||
|
||||
type_ = annotations.get(key)
|
||||
if type_ is None:
|
||||
# we do not have a type annotation for this field, leave it as is
|
||||
result[key] = value
|
||||
else:
|
||||
result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_)
|
||||
return result
|
||||
|
||||
|
||||
@lru_cache(maxsize=8096)
|
||||
def get_type_hints(
|
||||
obj: Any,
|
||||
globalns: dict[str, Any] | None = None,
|
||||
localns: Mapping[str, Any] | None = None,
|
||||
include_extras: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
return _get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras)
|
||||
151
src/tinker/_utils/_typing.py
Normal file
151
src/tinker/_utils/_typing.py
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import typing
|
||||
import typing_extensions
|
||||
from typing import Any, TypeVar, Iterable, cast
|
||||
from collections import abc as _c_abc
|
||||
from typing_extensions import (
|
||||
TypeIs,
|
||||
Required,
|
||||
Annotated,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
from ._utils import lru_cache
|
||||
from .._types import InheritsGeneric
|
||||
from .._compat import is_union as _is_union
|
||||
|
||||
|
||||
def is_annotated_type(typ: type) -> bool:
|
||||
return get_origin(typ) == Annotated
|
||||
|
||||
|
||||
def is_list_type(typ: type) -> bool:
|
||||
return (get_origin(typ) or typ) == list
|
||||
|
||||
|
||||
def is_iterable_type(typ: type) -> bool:
|
||||
"""If the given type is `typing.Iterable[T]`"""
|
||||
origin = get_origin(typ) or typ
|
||||
return origin == Iterable or origin == _c_abc.Iterable
|
||||
|
||||
|
||||
def is_union_type(typ: type) -> bool:
|
||||
return _is_union(get_origin(typ))
|
||||
|
||||
|
||||
def is_required_type(typ: type) -> bool:
|
||||
return get_origin(typ) == Required
|
||||
|
||||
|
||||
def is_typevar(typ: type) -> bool:
|
||||
# type ignore is required because type checkers
|
||||
# think this expression will always return False
|
||||
return type(typ) == TypeVar # type: ignore
|
||||
|
||||
|
||||
_TYPE_ALIAS_TYPES: tuple[type[typing_extensions.TypeAliasType], ...] = (typing_extensions.TypeAliasType,)
|
||||
if sys.version_info >= (3, 12):
|
||||
_TYPE_ALIAS_TYPES = (*_TYPE_ALIAS_TYPES, typing.TypeAliasType)
|
||||
|
||||
|
||||
def is_type_alias_type(tp: Any, /) -> TypeIs[typing_extensions.TypeAliasType]:
|
||||
"""Return whether the provided argument is an instance of `TypeAliasType`.
|
||||
|
||||
```python
|
||||
type Int = int
|
||||
is_type_alias_type(Int)
|
||||
# > True
|
||||
Str = TypeAliasType("Str", str)
|
||||
is_type_alias_type(Str)
|
||||
# > True
|
||||
```
|
||||
"""
|
||||
return isinstance(tp, _TYPE_ALIAS_TYPES)
|
||||
|
||||
|
||||
# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
|
||||
@lru_cache(maxsize=8096)
|
||||
def strip_annotated_type(typ: type) -> type:
|
||||
if is_required_type(typ) or is_annotated_type(typ):
|
||||
return strip_annotated_type(cast(type, get_args(typ)[0]))
|
||||
|
||||
return typ
|
||||
|
||||
|
||||
def extract_type_arg(typ: type, index: int) -> type:
|
||||
args = get_args(typ)
|
||||
try:
|
||||
return cast(type, args[index])
|
||||
except IndexError as err:
|
||||
raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err
|
||||
|
||||
|
||||
def extract_type_var_from_base(
|
||||
typ: type,
|
||||
*,
|
||||
generic_bases: tuple[type, ...],
|
||||
index: int,
|
||||
failure_message: str | None = None,
|
||||
) -> type:
|
||||
"""Given a type like `Foo[T]`, returns the generic type variable `T`.
|
||||
|
||||
This also handles the case where a concrete subclass is given, e.g.
|
||||
```py
|
||||
class MyResponse(Foo[bytes]):
|
||||
...
|
||||
|
||||
extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes
|
||||
```
|
||||
|
||||
And where a generic subclass is given:
|
||||
```py
|
||||
_T = TypeVar('_T')
|
||||
class MyResponse(Foo[_T]):
|
||||
...
|
||||
|
||||
extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes
|
||||
```
|
||||
"""
|
||||
cls = cast(object, get_origin(typ) or typ)
|
||||
if cls in generic_bases: # pyright: ignore[reportUnnecessaryContains]
|
||||
# we're given the class directly
|
||||
return extract_type_arg(typ, index)
|
||||
|
||||
# if a subclass is given
|
||||
# ---
|
||||
# this is needed as __orig_bases__ is not present in the typeshed stubs
|
||||
# because it is intended to be for internal use only, however there does
|
||||
# not seem to be a way to resolve generic TypeVars for inherited subclasses
|
||||
# without using it.
|
||||
if isinstance(cls, InheritsGeneric):
|
||||
target_base_class: Any | None = None
|
||||
for base in cls.__orig_bases__:
|
||||
if base.__origin__ in generic_bases:
|
||||
target_base_class = base
|
||||
break
|
||||
|
||||
if target_base_class is None:
|
||||
raise RuntimeError(
|
||||
"Could not find the generic base class;\n"
|
||||
"This should never happen;\n"
|
||||
f"Does {cls} inherit from one of {generic_bases} ?"
|
||||
)
|
||||
|
||||
extracted = extract_type_arg(target_base_class, index)
|
||||
if is_typevar(extracted):
|
||||
# If the extracted type argument is itself a type variable
|
||||
# then that means the subclass itself is generic, so we have
|
||||
# to resolve the type argument from the class itself, not
|
||||
# the base class.
|
||||
#
|
||||
# Note: if there is more than 1 type argument, the subclass could
|
||||
# change the ordering of the type arguments, this is not currently
|
||||
# supported.
|
||||
return extract_type_arg(typ, index)
|
||||
|
||||
return extracted
|
||||
|
||||
raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}")
|
||||
422
src/tinker/_utils/_utils.py
Normal file
422
src/tinker/_utils/_utils.py
Normal file
|
|
@ -0,0 +1,422 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
import inspect
|
||||
import functools
|
||||
from typing import (
|
||||
Any,
|
||||
Tuple,
|
||||
Mapping,
|
||||
TypeVar,
|
||||
Callable,
|
||||
Iterable,
|
||||
Sequence,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
from pathlib import Path
|
||||
from datetime import date, datetime
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
import sniffio
|
||||
|
||||
from .._types import NotGiven, FileTypes, NotGivenOr, HeadersLike
|
||||
from .._compat import parse_date as parse_date, parse_datetime as parse_datetime
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_TupleT = TypeVar("_TupleT", bound=Tuple[object, ...])
|
||||
_MappingT = TypeVar("_MappingT", bound=Mapping[str, object])
|
||||
_SequenceT = TypeVar("_SequenceT", bound=Sequence[object])
|
||||
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
|
||||
return [item for sublist in t for item in sublist]
|
||||
|
||||
|
||||
def extract_files(
|
||||
# TODO: this needs to take Dict but variance issues.....
|
||||
# create protocol type ?
|
||||
query: Mapping[str, object],
|
||||
*,
|
||||
paths: Sequence[Sequence[str]],
|
||||
) -> list[tuple[str, FileTypes]]:
|
||||
"""Recursively extract files from the given dictionary based on specified paths.
|
||||
|
||||
A path may look like this ['foo', 'files', '<array>', 'data'].
|
||||
|
||||
Note: this mutates the given dictionary.
|
||||
"""
|
||||
files: list[tuple[str, FileTypes]] = []
|
||||
for path in paths:
|
||||
files.extend(_extract_items(query, path, index=0, flattened_key=None))
|
||||
return files
|
||||
|
||||
|
||||
def _extract_items(
|
||||
obj: object,
|
||||
path: Sequence[str],
|
||||
*,
|
||||
index: int,
|
||||
flattened_key: str | None,
|
||||
) -> list[tuple[str, FileTypes]]:
|
||||
try:
|
||||
key = path[index]
|
||||
except IndexError:
|
||||
if isinstance(obj, NotGiven):
|
||||
# no value was provided - we can safely ignore
|
||||
return []
|
||||
|
||||
# cyclical import
|
||||
from .._files import assert_is_file_content
|
||||
|
||||
# We have exhausted the path, return the entry we found.
|
||||
assert flattened_key is not None
|
||||
|
||||
if is_list(obj):
|
||||
files: list[tuple[str, FileTypes]] = []
|
||||
for entry in obj:
|
||||
assert_is_file_content(entry, key=flattened_key + "[]" if flattened_key else "")
|
||||
files.append((flattened_key + "[]", cast(FileTypes, entry)))
|
||||
return files
|
||||
|
||||
assert_is_file_content(obj, key=flattened_key)
|
||||
return [(flattened_key, cast(FileTypes, obj))]
|
||||
|
||||
index += 1
|
||||
if is_dict(obj):
|
||||
try:
|
||||
# We are at the last entry in the path so we must remove the field
|
||||
if (len(path)) == index:
|
||||
item = obj.pop(key)
|
||||
else:
|
||||
item = obj[key]
|
||||
except KeyError:
|
||||
# Key was not present in the dictionary, this is not indicative of an error
|
||||
# as the given path may not point to a required field. We also do not want
|
||||
# to enforce required fields as the API may differ from the spec in some cases.
|
||||
return []
|
||||
if flattened_key is None:
|
||||
flattened_key = key
|
||||
else:
|
||||
flattened_key += f"[{key}]"
|
||||
return _extract_items(
|
||||
item,
|
||||
path,
|
||||
index=index,
|
||||
flattened_key=flattened_key,
|
||||
)
|
||||
elif is_list(obj):
|
||||
if key != "<array>":
|
||||
return []
|
||||
|
||||
return flatten(
|
||||
[
|
||||
_extract_items(
|
||||
item,
|
||||
path,
|
||||
index=index,
|
||||
flattened_key=flattened_key + "[]" if flattened_key is not None else "[]",
|
||||
)
|
||||
for item in obj
|
||||
]
|
||||
)
|
||||
|
||||
# Something unexpected was passed, just ignore it.
|
||||
return []
|
||||
|
||||
|
||||
def is_given(obj: NotGivenOr[_T]) -> TypeGuard[_T]:
|
||||
return not isinstance(obj, NotGiven)
|
||||
|
||||
|
||||
# Type safe methods for narrowing types with TypeVars.
|
||||
# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown],
|
||||
# however this cause Pyright to rightfully report errors. As we know we don't
|
||||
# care about the contained types we can safely use `object` in it's place.
|
||||
#
|
||||
# There are two separate functions defined, `is_*` and `is_*_t` for different use cases.
|
||||
# `is_*` is for when you're dealing with an unknown input
|
||||
# `is_*_t` is for when you're narrowing a known union type to a specific subset
|
||||
|
||||
|
||||
def is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]:
|
||||
return isinstance(obj, tuple)
|
||||
|
||||
|
||||
def is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]:
|
||||
return isinstance(obj, tuple)
|
||||
|
||||
|
||||
def is_sequence(obj: object) -> TypeGuard[Sequence[object]]:
|
||||
return isinstance(obj, Sequence)
|
||||
|
||||
|
||||
def is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]:
|
||||
return isinstance(obj, Sequence)
|
||||
|
||||
|
||||
def is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]:
|
||||
return isinstance(obj, Mapping)
|
||||
|
||||
|
||||
def is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]:
|
||||
return isinstance(obj, Mapping)
|
||||
|
||||
|
||||
def is_dict(obj: object) -> TypeGuard[dict[object, object]]:
|
||||
return isinstance(obj, dict)
|
||||
|
||||
|
||||
def is_list(obj: object) -> TypeGuard[list[object]]:
|
||||
return isinstance(obj, list)
|
||||
|
||||
|
||||
def is_iterable(obj: object) -> TypeGuard[Iterable[object]]:
|
||||
return isinstance(obj, Iterable)
|
||||
|
||||
|
||||
def deepcopy_minimal(item: _T) -> _T:
|
||||
"""Minimal reimplementation of copy.deepcopy() that will only copy certain object types:
|
||||
|
||||
- mappings, e.g. `dict`
|
||||
- list
|
||||
|
||||
This is done for performance reasons.
|
||||
"""
|
||||
if is_mapping(item):
|
||||
return cast(_T, {k: deepcopy_minimal(v) for k, v in item.items()})
|
||||
if is_list(item):
|
||||
return cast(_T, [deepcopy_minimal(entry) for entry in item])
|
||||
return item
|
||||
|
||||
|
||||
# copied from https://github.com/Rapptz/RoboDanny
|
||||
def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str:
|
||||
size = len(seq)
|
||||
if size == 0:
|
||||
return ""
|
||||
|
||||
if size == 1:
|
||||
return seq[0]
|
||||
|
||||
if size == 2:
|
||||
return f"{seq[0]} {final} {seq[1]}"
|
||||
|
||||
return delim.join(seq[:-1]) + f" {final} {seq[-1]}"
|
||||
|
||||
|
||||
def quote(string: str) -> str:
|
||||
"""Add single quotation marks around the given string. Does *not* do any escaping."""
|
||||
return f"'{string}'"
|
||||
|
||||
|
||||
def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
|
||||
"""Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function.
|
||||
|
||||
Useful for enforcing runtime validation of overloaded functions.
|
||||
|
||||
Example usage:
|
||||
```py
|
||||
@overload
|
||||
def foo(*, a: str) -> str: ...
|
||||
|
||||
|
||||
@overload
|
||||
def foo(*, b: bool) -> str: ...
|
||||
|
||||
|
||||
# This enforces the same constraints that a static type checker would
|
||||
# i.e. that either a or b must be passed to the function
|
||||
@required_args(["a"], ["b"])
|
||||
def foo(*, a: str | None = None, b: bool | None = None) -> str: ...
|
||||
```
|
||||
"""
|
||||
|
||||
def inner(func: CallableT) -> CallableT:
|
||||
params = inspect.signature(func).parameters
|
||||
positional = [
|
||||
name
|
||||
for name, param in params.items()
|
||||
if param.kind
|
||||
in {
|
||||
param.POSITIONAL_ONLY,
|
||||
param.POSITIONAL_OR_KEYWORD,
|
||||
}
|
||||
]
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: object, **kwargs: object) -> object:
|
||||
given_params: set[str] = set()
|
||||
for i, _ in enumerate(args):
|
||||
try:
|
||||
given_params.add(positional[i])
|
||||
except IndexError:
|
||||
raise TypeError(
|
||||
f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given"
|
||||
) from None
|
||||
|
||||
for key in kwargs.keys():
|
||||
given_params.add(key)
|
||||
|
||||
for variant in variants:
|
||||
matches = all((param in given_params for param in variant))
|
||||
if matches:
|
||||
break
|
||||
else: # no break
|
||||
if len(variants) > 1:
|
||||
variations = human_join(
|
||||
["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants]
|
||||
)
|
||||
msg = f"Missing required arguments; Expected either {variations} arguments to be given"
|
||||
else:
|
||||
assert len(variants) > 0
|
||||
|
||||
# TODO: this error message is not deterministic
|
||||
missing = list(set(variants[0]) - given_params)
|
||||
if len(missing) > 1:
|
||||
msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}"
|
||||
else:
|
||||
msg = f"Missing required argument: {quote(missing[0])}"
|
||||
raise TypeError(msg)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper # type: ignore
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
_K = TypeVar("_K")
|
||||
_V = TypeVar("_V")
|
||||
|
||||
|
||||
@overload
|
||||
def strip_not_given(obj: None) -> None: ...
|
||||
|
||||
|
||||
@overload
|
||||
def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def strip_not_given(obj: object) -> object: ...
|
||||
|
||||
|
||||
def strip_not_given(obj: object | None) -> object:
|
||||
"""Remove all top-level keys where their values are instances of `NotGiven`"""
|
||||
if obj is None:
|
||||
return None
|
||||
|
||||
if not is_mapping(obj):
|
||||
return obj
|
||||
|
||||
return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}
|
||||
|
||||
|
||||
def coerce_integer(val: str) -> int:
|
||||
return int(val, base=10)
|
||||
|
||||
|
||||
def coerce_float(val: str) -> float:
|
||||
return float(val)
|
||||
|
||||
|
||||
def coerce_boolean(val: str) -> bool:
|
||||
return val == "true" or val == "1" or val == "on"
|
||||
|
||||
|
||||
def maybe_coerce_integer(val: str | None) -> int | None:
|
||||
if val is None:
|
||||
return None
|
||||
return coerce_integer(val)
|
||||
|
||||
|
||||
def maybe_coerce_float(val: str | None) -> float | None:
|
||||
if val is None:
|
||||
return None
|
||||
return coerce_float(val)
|
||||
|
||||
|
||||
def maybe_coerce_boolean(val: str | None) -> bool | None:
|
||||
if val is None:
|
||||
return None
|
||||
return coerce_boolean(val)
|
||||
|
||||
|
||||
def removeprefix(string: str, prefix: str) -> str:
|
||||
"""Remove a prefix from a string.
|
||||
|
||||
Backport of `str.removeprefix` for Python < 3.9
|
||||
"""
|
||||
if string.startswith(prefix):
|
||||
return string[len(prefix) :]
|
||||
return string
|
||||
|
||||
|
||||
def removesuffix(string: str, suffix: str) -> str:
|
||||
"""Remove a suffix from a string.
|
||||
|
||||
Backport of `str.removesuffix` for Python < 3.9
|
||||
"""
|
||||
if string.endswith(suffix):
|
||||
return string[: -len(suffix)]
|
||||
return string
|
||||
|
||||
|
||||
def file_from_path(path: str) -> FileTypes:
|
||||
contents = Path(path).read_bytes()
|
||||
file_name = os.path.basename(path)
|
||||
return (file_name, contents)
|
||||
|
||||
|
||||
def get_required_header(headers: HeadersLike, header: str) -> str:
|
||||
lower_header = header.lower()
|
||||
if is_mapping_t(headers):
|
||||
# mypy doesn't understand the type narrowing here
|
||||
for k, v in headers.items(): # type: ignore
|
||||
if k.lower() == lower_header and isinstance(v, str):
|
||||
return v
|
||||
|
||||
# to deal with the case where the header looks like Stainless-Event-Id
|
||||
intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize())
|
||||
|
||||
for normalized_header in [header, lower_header, header.upper(), intercaps_header]:
|
||||
value = headers.get(normalized_header)
|
||||
if value:
|
||||
return value
|
||||
|
||||
raise ValueError(f"Could not find {header} header")
|
||||
|
||||
|
||||
def get_async_library() -> str:
|
||||
try:
|
||||
return sniffio.current_async_library()
|
||||
except Exception:
|
||||
return "false"
|
||||
|
||||
|
||||
def lru_cache(*, maxsize: int | None = 128) -> Callable[[CallableT], CallableT]:
|
||||
"""A version of functools.lru_cache that retains the type signature
|
||||
for the wrapped function arguments.
|
||||
"""
|
||||
wrapper = functools.lru_cache( # noqa: TID251
|
||||
maxsize=maxsize,
|
||||
)
|
||||
return cast(Any, wrapper) # type: ignore[no-any-return]
|
||||
|
||||
|
||||
def json_safe(data: object) -> object:
|
||||
"""Translates a mapping / sequence recursively in the same fashion
|
||||
as `pydantic` v2's `model_dump(mode="json")`.
|
||||
"""
|
||||
if is_mapping(data):
|
||||
return {json_safe(key): json_safe(value) for key, value in data.items()}
|
||||
|
||||
if is_iterable(data) and not isinstance(data, (str, bytes, bytearray)):
|
||||
return [json_safe(item) for item in data]
|
||||
|
||||
if isinstance(data, (datetime, date)):
|
||||
return data.isoformat()
|
||||
|
||||
return data
|
||||
4
src/tinker/_version.py
Normal file
4
src/tinker/_version.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
__title__ = "tinker"
|
||||
__version__ = "0.0.1-alpha.1"
|
||||
4
src/tinker/lib/.keep
Normal file
4
src/tinker/lib/.keep
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
File generated from our OpenAPI spec by Stainless.
|
||||
|
||||
This directory can be used to store custom files to expand the SDK.
|
||||
It is ignored by Stainless code generation and its content (other than this keep file) won't be touched.
|
||||
266
src/tinker/lib/api_future_impl.py
Normal file
266
src/tinker/lib/api_future_impl.py
Normal file
|
|
@ -0,0 +1,266 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import inspect
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Type, TypeVar, cast
|
||||
|
||||
import tinker
|
||||
from tinker import types
|
||||
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
|
||||
from tinker.lib.public_interfaces.api_future import APIFuture
|
||||
from tinker.lib.telemetry import Telemetry
|
||||
|
||||
from .._models import BaseModel
|
||||
from .retryable_exception import RetryableException
|
||||
from .sync_only import sync_only
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tinker.lib.internal_client_holder import InternalClientHolder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
U = TypeVar("U")
|
||||
|
||||
# Sentinel object to indicate that the function hasn't been called yet
|
||||
_UNCOMPUTED = object()
|
||||
|
||||
|
||||
class QueueState(Enum):
|
||||
ACTIVE = "active"
|
||||
PAUSED_RATE_LIMIT = "paused_rate_limit"
|
||||
PAUSED_CAPACITY = "paused_capacity"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class QueueStateObserver(ABC):
|
||||
@abstractmethod
|
||||
def on_queue_state_change(self, queue_state: QueueState) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
|
||||
def __init__(
|
||||
self,
|
||||
model_cls: Type[T],
|
||||
holder: InternalClientHolder,
|
||||
untyped_future: types.UntypedAPIFuture,
|
||||
request_start_time: float,
|
||||
request_type: str,
|
||||
queue_state_observer: QueueStateObserver | None = None,
|
||||
):
|
||||
self.model_cls = model_cls
|
||||
self.holder = holder
|
||||
self.untyped_future = untyped_future
|
||||
self.request_type = request_type
|
||||
self._cached_result: Any = _UNCOMPUTED
|
||||
|
||||
# This helps us collect telemetry about how long (1) it takes the
|
||||
# client to serialize the request, (2) round-trip time to the server
|
||||
# and back, and (3) how long the server takes to process the request.
|
||||
# We send this delta in a header to the server when retrieving the promise
|
||||
# result.
|
||||
self.request_start_time = request_start_time
|
||||
self.request_future_start_time = time.time()
|
||||
self.request_queue_roundtrip_time = self.request_future_start_time - request_start_time
|
||||
self._future = self.holder.run_coroutine_threadsafe(self._result_async())
|
||||
self._queue_state_observer: QueueStateObserver | None = queue_state_observer
|
||||
|
||||
async def _result_async(self, timeout: float | None = None) -> T:
|
||||
"""Get the result of this future, with automatic retries for transient errors."""
|
||||
if self._cached_result is not _UNCOMPUTED:
|
||||
return cast(T, self._cached_result)
|
||||
|
||||
start_time = time.time()
|
||||
iteration = -1
|
||||
connection_error_retries = 0
|
||||
|
||||
while True:
|
||||
iteration += 1
|
||||
|
||||
if timeout is not None and time.time() - start_time > timeout:
|
||||
if telemetry := self.get_telemetry():
|
||||
current_time = time.time()
|
||||
telemetry.log(
|
||||
"APIFuture.result_async.timeout",
|
||||
event_data={
|
||||
"request_id": self.request_id,
|
||||
"request_type": self.request_type,
|
||||
"timeout": timeout,
|
||||
"iteration": iteration,
|
||||
"elapsed_time": current_time - start_time,
|
||||
},
|
||||
severity="ERROR",
|
||||
)
|
||||
raise TimeoutError(
|
||||
f"Timeout of {timeout} seconds reached while waiting for result of {self.request_id=}"
|
||||
)
|
||||
|
||||
# Headers for telemetry
|
||||
headers = {
|
||||
"X-Tinker-Request-Iteration": str(iteration),
|
||||
"X-Tinker-Request-Type": self.request_type,
|
||||
}
|
||||
if iteration == 0:
|
||||
headers["X-Tinker-Create-Promise-Roundtrip-Time"] = str(
|
||||
self.request_queue_roundtrip_time
|
||||
)
|
||||
|
||||
# Function hasn't been called yet, execute it now
|
||||
try:
|
||||
with self.holder.aclient(ClientConnectionPoolType.RETRIEVE_PROMISE) as client:
|
||||
response = await client.futures.with_raw_response.retrieve(
|
||||
request_id=self.request_id, timeout=45, extra_headers=headers, max_retries=0
|
||||
)
|
||||
except tinker.APIStatusError as e:
|
||||
connection_error_retries = 0
|
||||
should_retry = e.status_code == 408 or e.status_code in range(500, 600)
|
||||
if telemetry := self.get_telemetry():
|
||||
current_time = time.time()
|
||||
telemetry.log(
|
||||
"APIFuture.result_async.api_status_error",
|
||||
event_data={
|
||||
"request_id": self.request_id,
|
||||
"request_type": self.request_type,
|
||||
"status_code": e.status_code,
|
||||
"exception": str(e),
|
||||
"should_retry": should_retry,
|
||||
"iteration": iteration,
|
||||
"elapsed_time": current_time - start_time,
|
||||
},
|
||||
severity="WARNING" if should_retry else "ERROR",
|
||||
)
|
||||
|
||||
# Retry 408s until we time out
|
||||
if e.status_code == 408:
|
||||
if self._queue_state_observer is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
response = e.response.json()
|
||||
if queue_state_str := response.get("queue_state", None):
|
||||
if queue_state_str == "active":
|
||||
queue_state = QueueState.ACTIVE
|
||||
elif queue_state_str == "paused_rate_limit":
|
||||
queue_state = QueueState.PAUSED_RATE_LIMIT
|
||||
elif queue_state_str == "paused_capacity":
|
||||
queue_state = QueueState.PAUSED_CAPACITY
|
||||
else:
|
||||
queue_state = QueueState.UNKNOWN
|
||||
self._queue_state_observer.on_queue_state_change(
|
||||
queue_state
|
||||
)
|
||||
continue
|
||||
if e.status_code == 410:
|
||||
raise RetryableException(
|
||||
message=f"Promise expired/broken for request {self.untyped_future.request_id}"
|
||||
) from e
|
||||
if e.status_code in range(500, 600):
|
||||
continue
|
||||
raise ValueError(
|
||||
f"Error retrieving result: {e} with status code {e.status_code=} for {self.request_id=} and expected type {self.model_cls=}"
|
||||
) from e
|
||||
except tinker.APIConnectionError as e:
|
||||
if telemetry := self.get_telemetry():
|
||||
current_time = time.time()
|
||||
telemetry.log(
|
||||
"APIFuture.result_async.connection_error",
|
||||
event_data={
|
||||
"request_id": self.request_id,
|
||||
"request_type": self.request_type,
|
||||
"exception": str(e),
|
||||
"connection_error_retries": connection_error_retries,
|
||||
"iteration": iteration,
|
||||
"elapsed_time": current_time - start_time,
|
||||
},
|
||||
severity="WARNING",
|
||||
)
|
||||
|
||||
# Retry all connection errors with exponential backoff
|
||||
await asyncio.sleep(min(2**connection_error_retries, 30))
|
||||
connection_error_retries += 1
|
||||
continue
|
||||
|
||||
# Function hasn't been called yet, execute it now
|
||||
result_dict: Any = await response.json()
|
||||
|
||||
if "type" in result_dict and result_dict["type"] == "try_again":
|
||||
logger.warning(f"Retrying request {self.request_id=} because of try_again")
|
||||
continue
|
||||
|
||||
if "error" in result_dict:
|
||||
raise ValueError(
|
||||
f"Error retrieving result: {result_dict} for {self.request_id=} and expected type {self.model_cls=}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Check if model_cls is a BaseModel subclass before calling model_validate
|
||||
if inspect.isclass(self.model_cls) and issubclass(self.model_cls, BaseModel):
|
||||
self._cached_result = self.model_cls.model_validate(result_dict)
|
||||
else:
|
||||
# For non-BaseModel types, just return the result directly
|
||||
self._cached_result = result_dict
|
||||
return cast(T, self._cached_result)
|
||||
except Exception as e:
|
||||
if telemetry := self.get_telemetry():
|
||||
current_time = time.time()
|
||||
telemetry.log(
|
||||
"APIFuture.result_async.validation_error",
|
||||
event_data={
|
||||
"request_id": self.request_id,
|
||||
"request_type": self.request_type,
|
||||
"exception": str(e),
|
||||
"exception_type": type(e).__name__,
|
||||
"exception_stack": "".join(
|
||||
traceback.format_exception(type(e), e, e.__traceback__)
|
||||
)
|
||||
if e.__traceback__
|
||||
else None,
|
||||
"model_cls": str(self.model_cls),
|
||||
"iteration": iteration,
|
||||
"elapsed_time": current_time - start_time,
|
||||
},
|
||||
severity="ERROR",
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Error retrieving result: {e} for {self.request_id=} and expected type {self.model_cls=}"
|
||||
) from e
|
||||
|
||||
@property
|
||||
def request_id(self) -> str:
|
||||
return self.untyped_future.request_id
|
||||
|
||||
@sync_only
|
||||
def result(self, timeout: float | None = None) -> T:
|
||||
return self._future.result(timeout)
|
||||
|
||||
async def result_async(self, timeout: float | None = None) -> T:
|
||||
return await asyncio.wait_for(self._future, timeout)
|
||||
|
||||
def get_telemetry(self) -> Telemetry | None:
|
||||
return self.holder.get_telemetry()
|
||||
|
||||
|
||||
class _CombinedAPIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
|
||||
def __init__(
|
||||
self,
|
||||
futures: List[APIFuture[T]],
|
||||
transform: Callable[[List[T]], T],
|
||||
holder: InternalClientHolder,
|
||||
):
|
||||
self.futures = futures
|
||||
self.transform = transform
|
||||
self.holder = holder
|
||||
|
||||
@sync_only
|
||||
def result(self, timeout: float | None = None) -> T:
|
||||
return self.holder.run_coroutine_threadsafe(self.result_async(timeout)).result()
|
||||
|
||||
async def result_async(self, timeout: float | None = None) -> T:
|
||||
results = await asyncio.gather(*[future.result_async(timeout) for future in self.futures])
|
||||
return self.transform(results)
|
||||
28
src/tinker/lib/async_tinker_provider.py
Normal file
28
src/tinker/lib/async_tinker_provider.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Coroutine
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import Any, Protocol, TypeVar
|
||||
|
||||
from tinker._client import AsyncTinker
|
||||
|
||||
from .public_interfaces.api_future import AwaitableConcurrentFuture
|
||||
from .client_connection_pool_type import ClientConnectionPoolType
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class AsyncTinkerProvider(Protocol):
|
||||
# both of the following methods should be threadsafe
|
||||
def get_loop(self) -> asyncio.AbstractEventLoop: ...
|
||||
|
||||
def run_coroutine_threadsafe(
|
||||
self,
|
||||
coro: Coroutine[Any, Any, T],
|
||||
) -> AwaitableConcurrentFuture[T]: ...
|
||||
|
||||
# must be called and used within the provided event loop
|
||||
def aclient(
|
||||
self, client_pool_type: ClientConnectionPoolType
|
||||
) -> AbstractContextManager[AsyncTinker]: ...
|
||||
126
src/tinker/lib/chunked_fwdbwd_helpers.py
Normal file
126
src/tinker/lib/chunked_fwdbwd_helpers.py
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
import logging
|
||||
from typing import Any, Dict, List, Sequence, Set, cast
|
||||
|
||||
import numpy as np
|
||||
from tinker.types import ForwardBackwardOutput, LossFnOutput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Metrics = Dict[str, float]
|
||||
|
||||
|
||||
def combine_fwd_bwd_output_results(
|
||||
results: Sequence[ForwardBackwardOutput],
|
||||
) -> ForwardBackwardOutput:
|
||||
if not results:
|
||||
return ForwardBackwardOutput(loss_fn_output_type="", metrics={}, loss_fn_outputs=[])
|
||||
|
||||
combined_metrics = _metrics_reduction(results)
|
||||
combined_outputs = _combine_loss_fn_outputs(results)
|
||||
|
||||
return ForwardBackwardOutput(
|
||||
loss_fn_output_type=results[0].loss_fn_output_type,
|
||||
metrics=combined_metrics,
|
||||
loss_fn_outputs=combined_outputs,
|
||||
)
|
||||
|
||||
|
||||
def _combine_loss_fn_outputs(results: Sequence[ForwardBackwardOutput]) -> List[LossFnOutput]:
|
||||
return [output for result in results for output in result.loss_fn_outputs]
|
||||
|
||||
|
||||
def _order_insensitive_hash(xs: Sequence[Set[Any]] | Sequence[float]) -> int:
|
||||
"""Combine hash values in an order-insensitive way.
|
||||
|
||||
Args:
|
||||
xs: Either a sequence of sets (original data) or a sequence of already-computed hash values
|
||||
"""
|
||||
# If we have sets, flatten and hash them (original behavior)
|
||||
if xs and isinstance(xs[0], set):
|
||||
return hash(tuple(sorted([y for x in xs for y in cast(Set[Any], x)])))
|
||||
|
||||
# If we have already-computed hash values, combine them deterministically
|
||||
# Sort them to ensure order-insensitive combination
|
||||
return hash(tuple(sorted(int(cast(float, x)) for x in xs)))
|
||||
|
||||
|
||||
def _mean(xs: Sequence[float | int], weights: Sequence[float] | None = None) -> float:
|
||||
if weights is None:
|
||||
return np.mean(xs).item()
|
||||
return np.average(xs, weights=weights).item()
|
||||
|
||||
|
||||
def _sum(xs: Sequence[float | int]) -> float:
|
||||
return np.sum(xs).item()
|
||||
|
||||
|
||||
def _min(xs: Sequence[float | int]) -> float:
|
||||
return np.min(xs).item()
|
||||
|
||||
|
||||
def _max(xs: Sequence[float | int]) -> float:
|
||||
return np.max(xs).item()
|
||||
|
||||
|
||||
def _slack(xs: Sequence[float | int], weights: Sequence[float] | None = None) -> float:
|
||||
if weights is None:
|
||||
return (np.max(xs) - np.mean(xs)).item()
|
||||
return (np.max(xs) - np.average(xs, weights=weights)).item()
|
||||
|
||||
|
||||
def _unique(xs: Sequence[float | int]) -> Sequence[float | int]:
|
||||
"""
|
||||
A unique metric can't actually fold. But in order to work around the fact that
|
||||
it's a str:float dict, we just insert unique keys with some suffix for each
|
||||
unique value. It's a hack, that's for sure.
|
||||
|
||||
This is a dummy identity function just for documentation and consistency purposes.
|
||||
"""
|
||||
return xs
|
||||
|
||||
|
||||
REDUCE_MAP = {
|
||||
"mean": _mean,
|
||||
"sum": _sum,
|
||||
"min": _min,
|
||||
"max": _max,
|
||||
"slack": _slack,
|
||||
"hash_unordered": _order_insensitive_hash,
|
||||
"unique": _unique,
|
||||
}
|
||||
|
||||
|
||||
def _metrics_reduction(results: Sequence[ForwardBackwardOutput]) -> Metrics:
|
||||
"""Reduce metrics from all actors.
|
||||
every metric must indicate a reduction_type in its name for example "mfu:mean"
|
||||
|
||||
Metrics are weighted by the number of loss_fn_outputs (data points) each actor processed.
|
||||
"""
|
||||
if not results:
|
||||
return {}
|
||||
keys = results[0].metrics.keys()
|
||||
|
||||
weights = [len(m.loss_fn_outputs) for m in results]
|
||||
|
||||
res = {}
|
||||
for key in keys:
|
||||
name, reduction = key.split(":")
|
||||
if reduction not in REDUCE_MAP:
|
||||
# Can happen when a new reduction type is added
|
||||
logger.debug(
|
||||
f"Invalid {reduction=} for metric {name=}. Expecting one of {REDUCE_MAP.keys()}"
|
||||
)
|
||||
continue
|
||||
if not all(key in m.metrics for m in results):
|
||||
continue
|
||||
reduce_fn = REDUCE_MAP[reduction]
|
||||
values = [m.metrics[key] for m in results]
|
||||
|
||||
if reduction in ["mean", "slack"]:
|
||||
res[key] = reduce_fn(values, weights)
|
||||
elif reduction in ["unique"]:
|
||||
res[key] = values[0]
|
||||
res.update({f"{key}_{i + 1}": v for i, v in enumerate(values[1:])})
|
||||
else:
|
||||
res[key] = reduce_fn(values)
|
||||
return res
|
||||
8
src/tinker/lib/client_connection_pool_type.py
Normal file
8
src/tinker/lib/client_connection_pool_type.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
from enum import Enum
|
||||
|
||||
|
||||
class ClientConnectionPoolType(Enum):
|
||||
SAMPLE = "sample"
|
||||
TRAIN = "train"
|
||||
RETRIEVE_PROMISE = "retrieve_promise"
|
||||
TELEMETRY = "telemetry"
|
||||
234
src/tinker/lib/internal_client_holder.py
Normal file
234
src/tinker/lib/internal_client_holder.py
Normal file
|
|
@ -0,0 +1,234 @@
|
|||
"""Internal client holder for managing AsyncTinker clients."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from collections.abc import Coroutine, Generator
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from typing import Any, Awaitable, Callable, TypeVar
|
||||
|
||||
import httpx
|
||||
|
||||
from tinker._client import AsyncTinker
|
||||
from tinker._exceptions import APIConnectionError, APIStatusError
|
||||
from tinker.lib.async_tinker_provider import AsyncTinkerProvider
|
||||
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
|
||||
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
|
||||
from tinker.lib.telemetry import Telemetry, init_telemetry
|
||||
from tinker.lib.telemetry_provider import TelemetryProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
MAX_REQUESTS_PER_HTTPX_CLIENT = 50
|
||||
|
||||
|
||||
class ClientConnectionPool:
|
||||
def __init__(
|
||||
self,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
max_requests_per_client: int,
|
||||
constructor_kwargs: dict[str, Any],
|
||||
):
|
||||
self._loop = loop
|
||||
self._max_requests_per_client = max_requests_per_client
|
||||
self._constructor_kwargs = constructor_kwargs
|
||||
self._clients: list[AsyncTinker] = []
|
||||
self._client_active_refcount: list[int] = []
|
||||
|
||||
@contextmanager
|
||||
def aclient(self) -> Generator[AsyncTinker, None, None]:
|
||||
assert _current_loop() is self._loop, "AsyncTinker client called from incorrect event loop"
|
||||
client_idx = -1
|
||||
for i, ref_count in enumerate(self._client_active_refcount):
|
||||
if ref_count < self._max_requests_per_client:
|
||||
client_idx = i
|
||||
break
|
||||
if client_idx == -1:
|
||||
self._clients.append(AsyncTinker(**self._constructor_kwargs))
|
||||
client_idx = len(self._clients) - 1
|
||||
self._client_active_refcount.append(0)
|
||||
|
||||
self._client_active_refcount[client_idx] += 1
|
||||
try:
|
||||
yield self._clients[client_idx]
|
||||
finally:
|
||||
self._client_active_refcount[client_idx] -= 1
|
||||
|
||||
|
||||
class InternalClientHolderThreadSingleton:
|
||||
def __init__(self):
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
self._thread: threading.Thread | None = None
|
||||
self._started: bool = False
|
||||
self._lifecycle_lock: threading.Lock = threading.Lock()
|
||||
|
||||
def _ensure_started(self):
|
||||
if self._started:
|
||||
return
|
||||
|
||||
with self._lifecycle_lock:
|
||||
if self._started:
|
||||
return
|
||||
self._loop = asyncio.new_event_loop()
|
||||
self._thread = threading.Thread(target=self._background_thread_func, daemon=True)
|
||||
self._thread.start()
|
||||
self._started = True
|
||||
|
||||
def _background_thread_func(self):
|
||||
assert self._loop is not None, "Loop must not be None"
|
||||
self._loop.run_forever()
|
||||
|
||||
def get_loop(self) -> asyncio.AbstractEventLoop:
|
||||
self._ensure_started()
|
||||
assert self._loop is not None, "Loop must not be None"
|
||||
return self._loop
|
||||
|
||||
|
||||
_internal_client_holder_thread_singleton = InternalClientHolderThreadSingleton()
|
||||
|
||||
|
||||
class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
self._constructor_kwargs = kwargs
|
||||
# So we can use async eventloop for parallel sampling requests
|
||||
# in sync code.
|
||||
self._loop: asyncio.AbstractEventLoop = _internal_client_holder_thread_singleton.get_loop()
|
||||
self._client_pools: dict[ClientConnectionPoolType, ClientConnectionPool] = {}
|
||||
self._sample_backoff_until: float | None = None
|
||||
self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400)
|
||||
self._session_id: str = str(uuid.uuid4())
|
||||
self._telemetry: Telemetry | None = init_telemetry(self, session_id=self._session_id)
|
||||
|
||||
self._training_client_counter: int = 0
|
||||
self._training_client_lock: threading.Lock = threading.Lock()
|
||||
|
||||
self._unordered_id_counter: int = 0
|
||||
|
||||
def _get_client_connection_pool(
|
||||
self, client_pool_type: ClientConnectionPoolType
|
||||
) -> ClientConnectionPool:
|
||||
if client_pool_type not in self._client_pools:
|
||||
max_requests_per_client = (
|
||||
1
|
||||
if client_pool_type == ClientConnectionPoolType.TRAIN
|
||||
else MAX_REQUESTS_PER_HTTPX_CLIENT
|
||||
)
|
||||
self._client_pools[client_pool_type] = ClientConnectionPool(
|
||||
self.get_loop(), max_requests_per_client, self._constructor_kwargs
|
||||
)
|
||||
return self._client_pools[client_pool_type]
|
||||
|
||||
def get_training_client_id(self) -> int:
|
||||
with self._training_client_lock:
|
||||
training_client_id = self._training_client_counter
|
||||
self._training_client_counter += 1
|
||||
return training_client_id
|
||||
|
||||
def aclient(
|
||||
self, client_pool_type: ClientConnectionPoolType
|
||||
) -> AbstractContextManager[AsyncTinker]:
|
||||
return self._get_client_connection_pool(client_pool_type).aclient()
|
||||
|
||||
def get_loop(self) -> asyncio.AbstractEventLoop:
|
||||
return self._loop
|
||||
|
||||
def get_telemetry(self) -> Telemetry | None:
|
||||
return self._telemetry
|
||||
|
||||
def run_coroutine_threadsafe(
|
||||
self,
|
||||
coro: Coroutine[Any, Any, T],
|
||||
) -> AwaitableConcurrentFuture[T]:
|
||||
return AwaitableConcurrentFuture(asyncio.run_coroutine_threadsafe(coro, self.get_loop()))
|
||||
|
||||
def close(self):
|
||||
if telemetry := self._telemetry:
|
||||
telemetry.stop()
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
def make_training_client_idempotency_key(self, training_client_id: int, request_id: int) -> str:
|
||||
return f"{self._session_id}:{training_client_id}:{request_id}"
|
||||
|
||||
def make_idempotency_key(self) -> str:
|
||||
self._unordered_id_counter += 1
|
||||
return f"{self._session_id}:unordered:{self._unordered_id_counter}"
|
||||
|
||||
@staticmethod
|
||||
def _is_retryable_status_code(status_code: int) -> bool:
|
||||
return status_code in (408, 409, 429) or (500 <= status_code < 600)
|
||||
|
||||
@staticmethod
|
||||
def _is_retryable_exception(exception: Exception) -> bool:
|
||||
RETRYABLE_EXCEPTIONS = (
|
||||
asyncio.TimeoutError,
|
||||
APIConnectionError,
|
||||
httpx.TimeoutException,
|
||||
)
|
||||
if isinstance(exception, RETRYABLE_EXCEPTIONS):
|
||||
return True
|
||||
if isinstance(exception, APIStatusError):
|
||||
return InternalClientHolder._is_retryable_status_code(exception.status_code)
|
||||
return False
|
||||
|
||||
async def execute_with_retries(
|
||||
self, func: Callable[..., Awaitable[T]], *args: Any, **kwargs: Any
|
||||
) -> T:
|
||||
MAX_WAIT_TIME = 60 * 5
|
||||
start_time = time.time()
|
||||
attempt_count = 0
|
||||
while True:
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
is_retryable = self._is_retryable_exception(e)
|
||||
current_time = time.time()
|
||||
elapsed_time = current_time - start_time
|
||||
if telemetry := self.get_telemetry():
|
||||
telemetry.log(
|
||||
"InternalClientHolder.execute_with_retries.exception",
|
||||
event_data={
|
||||
"func": getattr(
|
||||
func, "__qualname__", getattr(func, "__name__", type(func).__name__)
|
||||
),
|
||||
"exception": str(e),
|
||||
"exception_type": type(e).__name__,
|
||||
"exception_stack": "".join(
|
||||
traceback.format_exception(type(e), e, e.__traceback__)
|
||||
)
|
||||
if e.__traceback__
|
||||
else None,
|
||||
"status_code": getattr(e, "status_code", None),
|
||||
"is_retryable": is_retryable,
|
||||
"attempt_count": attempt_count,
|
||||
"start_time": start_time,
|
||||
"current_time": current_time,
|
||||
"elapsed_time": elapsed_time,
|
||||
},
|
||||
severity="WARNING" if is_retryable else "ERROR",
|
||||
)
|
||||
if is_retryable and elapsed_time < MAX_WAIT_TIME:
|
||||
# Apply exponential backoff
|
||||
time_to_wait = min(2**attempt_count, 30)
|
||||
attempt_count += 1
|
||||
# Don't wait too long if we're almost at the max wait time
|
||||
time_to_wait = min(time_to_wait, start_time + MAX_WAIT_TIME - current_time)
|
||||
await asyncio.sleep(time_to_wait)
|
||||
continue
|
||||
|
||||
raise e
|
||||
|
||||
|
||||
def _current_loop() -> asyncio.AbstractEventLoop | None:
|
||||
try:
|
||||
return asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return None
|
||||
14
src/tinker/lib/public_interfaces/__init__.py
Normal file
14
src/tinker/lib/public_interfaces/__init__.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
"""Public interfaces for the Tinker client library."""
|
||||
|
||||
from .api_future import APIFuture, AwaitableConcurrentFuture
|
||||
from .sampling_client import SamplingClient
|
||||
from .service_client import ServiceClient
|
||||
from .training_client import TrainingClient
|
||||
|
||||
__all__ = [
|
||||
"ServiceClient",
|
||||
"TrainingClient",
|
||||
"SamplingClient",
|
||||
"APIFuture",
|
||||
"AwaitableConcurrentFuture",
|
||||
]
|
||||
40
src/tinker/lib/public_interfaces/api_future.py
Normal file
40
src/tinker/lib/public_interfaces/api_future.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
"""
|
||||
API Future classes for handling async operations with retry logic.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import Future as ConcurrentFuture
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class APIFuture(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
async def result_async(self, timeout: float | None = None) -> T:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def result(self, timeout: float | None = None) -> T:
|
||||
raise NotImplementedError
|
||||
|
||||
def __await__(self):
|
||||
return self.result_async().__await__()
|
||||
|
||||
|
||||
class AwaitableConcurrentFuture(APIFuture[T]):
|
||||
def __init__(self, future: ConcurrentFuture[T]):
|
||||
self._future: ConcurrentFuture[T] = future
|
||||
|
||||
def result(self, timeout: float | None = None) -> T:
|
||||
return self._future.result(timeout)
|
||||
|
||||
async def result_async(self, timeout: float | None = None) -> T:
|
||||
async with asyncio.timeout(timeout):
|
||||
return await asyncio.wrap_future(self._future)
|
||||
|
||||
def future(self) -> ConcurrentFuture[T]:
|
||||
return self._future
|
||||
397
src/tinker/lib/public_interfaces/rest_client.py
Normal file
397
src/tinker/lib/public_interfaces/rest_client.py
Normal file
|
|
@ -0,0 +1,397 @@
|
|||
"""RestClient for Tinker API REST operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from concurrent.futures import Future as ConcurrentFuture
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from tinker import types, NoneType
|
||||
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
|
||||
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
|
||||
from tinker.lib.telemetry import Telemetry, capture_exceptions
|
||||
from tinker.lib.telemetry_provider import TelemetryProvider
|
||||
|
||||
from ..sync_only import sync_only
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..internal_client_holder import InternalClientHolder
|
||||
|
||||
# pyright: reportPrivateImportUsage=false
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RestClient(TelemetryProvider):
|
||||
"""Client for REST API operations like listing checkpoints and metadata.
|
||||
|
||||
The RestClient provides access to various REST endpoints for querying
|
||||
model information, checkpoints, and other resources. You typically get one
|
||||
by calling `service_client.create_rest_client()`.
|
||||
|
||||
Key methods:
|
||||
- list_checkpoints() - list available model checkpoints (both training and sampler)
|
||||
- get_training_run() - get model information and metadata as ModelEntry
|
||||
- delete_checkpoint() - delete an existing checkpoint for a training run
|
||||
- download_sampler_weights_archive() - download sampler weights checkpoint as tar.gz archive
|
||||
|
||||
Args:
|
||||
holder: Internal client managing HTTP connections and async operations
|
||||
|
||||
Example:
|
||||
>>> rest_client = service_client.create_rest_client()
|
||||
>>> training_run = rest_client.get_training_run("run-id").result()
|
||||
>>> print(f"Training Run: {training_run.training_run_id}, LoRA: {training_run.is_lora}")
|
||||
>>> checkpoints = rest_client.list_checkpoints("run-id").result()
|
||||
>>> print(f"Found {len(checkpoints.checkpoints)} checkpoints")
|
||||
>>> for checkpoint in checkpoints.checkpoints:
|
||||
... print(f" {checkpoint.checkpoint_type}: {checkpoint.checkpoint_id}")
|
||||
"""
|
||||
|
||||
def __init__(self, holder: InternalClientHolder):
|
||||
self.holder = holder
|
||||
|
||||
def _get_training_run_submit(
|
||||
self, training_run_id: types.ModelID
|
||||
) -> AwaitableConcurrentFuture[types.TrainingRun]:
|
||||
"""Internal method to submit get model request."""
|
||||
async def _get_training_run_async() -> types.TrainingRun:
|
||||
async def _send_request() -> types.TrainingRun:
|
||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||
return await client.get(
|
||||
f"/api/v1/training_runs/{training_run_id}",
|
||||
cast_to=types.TrainingRun,
|
||||
)
|
||||
|
||||
return await self.holder.execute_with_retries(_send_request)
|
||||
|
||||
return self.holder.run_coroutine_threadsafe(_get_training_run_async())
|
||||
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def get_training_run(self, training_run_id: types.ModelID) -> ConcurrentFuture[types.TrainingRun]:
|
||||
"""Get training run info.
|
||||
|
||||
Args:
|
||||
training_run_id: The training run ID to get information for
|
||||
|
||||
Returns:
|
||||
A Future containing the training run information
|
||||
|
||||
Example:
|
||||
>>> future = rest_client.get_training_run("run-id")
|
||||
>>> response = future.result()
|
||||
>>> print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}")
|
||||
"""
|
||||
return self._get_training_run_submit(training_run_id).future()
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def get_training_run_async(self, training_run_id: types.ModelID) -> types.TrainingRun:
|
||||
"""Async version of get_training_run.
|
||||
|
||||
Args:
|
||||
training_run_id: The training run ID to get information for
|
||||
|
||||
Returns:
|
||||
Training run information
|
||||
|
||||
Example:
|
||||
>>> response = await rest_client.get_training_run_async("run-id")
|
||||
>>> print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}")
|
||||
"""
|
||||
return await self._get_training_run_submit(training_run_id)
|
||||
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def get_training_run_by_tinker_path(self, tinker_path: str) -> ConcurrentFuture[types.TrainingRun]:
|
||||
"""Get training run info.
|
||||
|
||||
Args:
|
||||
tinker_path: The tinker path to the checkpoint
|
||||
|
||||
Returns:
|
||||
A Future containing the training run information
|
||||
|
||||
Example:
|
||||
>>> future = rest_client.get_training_run_by_tinker_path("tinker://run-id/weights/checkpoint-001")
|
||||
>>> response = future.result()
|
||||
>>> print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}")
|
||||
"""
|
||||
parsed_checkpoint_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(
|
||||
tinker_path
|
||||
)
|
||||
return self.get_training_run(parsed_checkpoint_tinker_path.training_run_id)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def get_training_run_by_tinker_path_async(self, tinker_path: str) -> types.TrainingRun:
|
||||
"""Async version of get_training_run.
|
||||
|
||||
Args:
|
||||
tinker_path: The tinker path to the checkpoint
|
||||
|
||||
Returns:
|
||||
Training run information
|
||||
|
||||
Example:
|
||||
>>> response = await rest_client.get_training_run_by_tinker_path_async("tinker://run-id/weights/checkpoint-001")
|
||||
>>> print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}")
|
||||
"""
|
||||
parsed_checkpoint_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(
|
||||
tinker_path
|
||||
)
|
||||
return await self.get_training_run_async(parsed_checkpoint_tinker_path.training_run_id)
|
||||
|
||||
def _list_training_runs_submit(
|
||||
self, limit: int = 20, offset: int = 0
|
||||
) -> AwaitableConcurrentFuture[types.TrainingRunsResponse]:
|
||||
"""Internal method to submit list training runs request."""
|
||||
async def _list_training_runs_async() -> types.TrainingRunsResponse:
|
||||
async def _send_request() -> types.TrainingRunsResponse:
|
||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||
params: dict[str, object] = {"limit": limit, "offset": offset}
|
||||
|
||||
return await client.get(
|
||||
"/api/v1/training_runs",
|
||||
options={"params": params},
|
||||
cast_to=types.TrainingRunsResponse,
|
||||
)
|
||||
|
||||
return await self.holder.execute_with_retries(_send_request)
|
||||
|
||||
return self.holder.run_coroutine_threadsafe(_list_training_runs_async())
|
||||
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def list_training_runs(
|
||||
self, limit: int = 20, offset: int = 0
|
||||
) -> ConcurrentFuture[types.TrainingRunsResponse]:
|
||||
"""List training runs with pagination support.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of training runs to return (default 20)
|
||||
offset: Offset for pagination (default 0)
|
||||
|
||||
Returns:
|
||||
A Future containing the TrainingRunsResponse with training runs and cursor info
|
||||
|
||||
Example:
|
||||
>>> future = rest_client.list_training_runs(limit=50)
|
||||
>>> response = future.result()
|
||||
>>> print(f"Found {len(response.training_runs)} training runs")
|
||||
>>> print(f"Total: {response.cursor.total_count}")
|
||||
>>> # Get next page
|
||||
>>> next_page = rest_client.list_training_runs(limit=50, offset=50)
|
||||
"""
|
||||
return self._list_training_runs_submit(limit, offset).future()
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def list_training_runs_async(
|
||||
self, limit: int = 20, offset: int = 0
|
||||
) -> types.TrainingRunsResponse:
|
||||
"""Async version of list_training_runs.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of training runs to return (default 20)
|
||||
offset: Offset for pagination (default 0)
|
||||
|
||||
Returns:
|
||||
TrainingRunsResponse with training runs and cursor info
|
||||
|
||||
Example:
|
||||
>>> response = await rest_client.list_training_runs_async(limit=50)
|
||||
>>> print(f"Found {len(response.training_runs)} training runs")
|
||||
>>> print(f"Total: {response.cursor.total_count}")
|
||||
>>> # Get next page
|
||||
>>> next_page = await rest_client.list_training_runs_async(limit=50, offset=50)
|
||||
"""
|
||||
return await self._list_training_runs_submit(limit, offset)
|
||||
|
||||
def _list_checkpoints_submit(
|
||||
self, training_run_id: types.ModelID
|
||||
) -> AwaitableConcurrentFuture[types.CheckpointsListResponse]:
|
||||
"""Internal method to submit list model checkpoints request."""
|
||||
async def _list_checkpoints_async():
|
||||
async def _send_request():
|
||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||
return await client.weights.list(training_run_id)
|
||||
|
||||
return await self.holder.execute_with_retries(_send_request)
|
||||
|
||||
return self.holder.run_coroutine_threadsafe(_list_checkpoints_async())
|
||||
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def list_checkpoints(self, training_run_id: types.ModelID) -> ConcurrentFuture[types.CheckpointsListResponse]:
|
||||
"""List available checkpoints (both training and sampler).
|
||||
|
||||
Args:
|
||||
training_run_id: The training run ID to list checkpoints for
|
||||
|
||||
Returns:
|
||||
A Future containing the CheckpointsListResponse with available checkpoints
|
||||
|
||||
Example:
|
||||
>>> future = rest_client.list_checkpoints("run-id")
|
||||
>>> response = future.result()
|
||||
>>> for checkpoint in response.checkpoints:
|
||||
... if checkpoint.checkpoint_type == "training":
|
||||
... print(f"Training checkpoint: {checkpoint.checkpoint_id}")
|
||||
... elif checkpoint.checkpoint_type == "sampler":
|
||||
... print(f"Sampler checkpoint: {checkpoint.checkpoint_id}")
|
||||
"""
|
||||
return self._list_checkpoints_submit(training_run_id).future()
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def list_checkpoints_async(self, training_run_id: types.ModelID) -> types.CheckpointsListResponse:
|
||||
"""Async version of list_checkpoints.
|
||||
|
||||
Args:
|
||||
training_run_id: The training run ID to list checkpoints for
|
||||
|
||||
Returns:
|
||||
CheckpointsListResponse with available checkpoints
|
||||
|
||||
Example:
|
||||
>>> response = await rest_client.list_checkpoints_async("run-id")
|
||||
>>> for checkpoint in response.checkpoints:
|
||||
... if checkpoint.checkpoint_type == "training":
|
||||
... print(f"Training checkpoint: {checkpoint.checkpoint_id}")
|
||||
... elif checkpoint.checkpoint_type == "sampler":
|
||||
... print(f"Sampler checkpoint: {checkpoint.checkpoint_id}")
|
||||
"""
|
||||
return await self._list_checkpoints_submit(training_run_id)
|
||||
|
||||
def _download_checkpoint_archive_submit(
|
||||
self, training_run_id: types.ModelID, checkpoint_id: str
|
||||
) -> AwaitableConcurrentFuture[bytes]:
|
||||
"""Internal method to submit download checkpoint archive request."""
|
||||
async def _download_checkpoint_archive_async():
|
||||
async def _send_request():
|
||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||
return await client.get(
|
||||
f"/api/v1/training_runs/{training_run_id}/checkpoints/{checkpoint_id}/archive",
|
||||
cast_to=bytes,
|
||||
options={"headers": {"accept": "application/gzip"}},
|
||||
)
|
||||
|
||||
return await self.holder.execute_with_retries(_send_request)
|
||||
|
||||
return self.holder.run_coroutine_threadsafe(_download_checkpoint_archive_async())
|
||||
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def download_checkpoint_archive(
|
||||
self, training_run_id: types.ModelID, checkpoint_id: str
|
||||
) -> ConcurrentFuture[bytes]:
|
||||
"""Download checkpoint as a tar.gz archive.
|
||||
|
||||
Args:
|
||||
training_run_id: The training run ID to download weights for
|
||||
checkpoint_id: The checkpoint ID to download
|
||||
|
||||
Returns:
|
||||
A Future containing the archive data as bytes
|
||||
|
||||
Example:
|
||||
>>> future = rest_client.download_checkpoint_archive("run-id", "checkpoint-123")
|
||||
>>> archive_data = future.result()
|
||||
>>> with open(f"model-checkpoint.tar.gz", "wb") as f:
|
||||
... f.write(archive_data)
|
||||
"""
|
||||
return self._download_checkpoint_archive_submit(training_run_id, checkpoint_id).future()
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def download_checkpoint_archive_async(
|
||||
self, training_run_id: types.ModelID, checkpoint_id: str
|
||||
) -> bytes:
|
||||
"""Async version of download_checkpoint_archive.
|
||||
|
||||
Args:
|
||||
training_run_id: The model ID to download weights for
|
||||
checkpoint_id: The checkpoint ID to download
|
||||
|
||||
Returns:
|
||||
Archive data as bytes
|
||||
|
||||
Example:
|
||||
>>> archive_data = await rest_client.download_checkpoint_archive_async("run-id", "checkpoint-123")
|
||||
>>> with open(f"model-checkpoint.tar.gz", "wb") as f:
|
||||
... f.write(archive_data)
|
||||
"""
|
||||
return await self._download_checkpoint_archive_submit(training_run_id, checkpoint_id)
|
||||
|
||||
def _delete_checkpoint_submit(
|
||||
self, training_run_id: types.ModelID, checkpoint_id: str
|
||||
) -> AwaitableConcurrentFuture[None]:
|
||||
"""Internal method to submit delete checkpoint request."""
|
||||
|
||||
async def _delete_checkpoint_async() -> None:
|
||||
async def _send_request() -> None:
|
||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||
await client.delete(
|
||||
f"/api/v1/training_runs/{training_run_id}/checkpoints/{checkpoint_id}",
|
||||
cast_to=NoneType,
|
||||
)
|
||||
|
||||
return await self.holder.execute_with_retries(_send_request)
|
||||
|
||||
return self.holder.run_coroutine_threadsafe(_delete_checkpoint_async())
|
||||
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def delete_checkpoint(self, training_run_id: types.ModelID, checkpoint_id: str) -> ConcurrentFuture[None]:
|
||||
"""Delete a checkpoint for a training run."""
|
||||
|
||||
return self._delete_checkpoint_submit(training_run_id, checkpoint_id).future()
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def delete_checkpoint_async(self, training_run_id: types.ModelID, checkpoint_id: str) -> None:
|
||||
"""Async version of delete_checkpoint."""
|
||||
|
||||
await self._delete_checkpoint_submit(training_run_id, checkpoint_id)
|
||||
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def delete_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]:
|
||||
"""Delete a checkpoint referenced by a tinker path."""
|
||||
|
||||
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
|
||||
return self._delete_checkpoint_submit(parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id).future()
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def delete_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
|
||||
"""Async version of delete_checkpoint_from_tinker_path."""
|
||||
|
||||
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
|
||||
await self._delete_checkpoint_submit(parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id)
|
||||
|
||||
def get_telemetry(self) -> Telemetry | None:
|
||||
return self.holder.get_telemetry()
|
||||
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def download_checkpoint_archive_from_tinker_path(
|
||||
self, tinker_path: str
|
||||
) -> ConcurrentFuture[bytes]:
|
||||
"""Download checkpoint as a tar.gz archive.
|
||||
|
||||
Args:
|
||||
tinker_path: The tinker path to the checkpoint
|
||||
|
||||
Returns:
|
||||
A Future containing the archive data as bytes
|
||||
"""
|
||||
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
|
||||
return self._download_checkpoint_archive_submit(parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id).future()
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def download_checkpoint_archive_from_tinker_path_async(
|
||||
self, tinker_path: str
|
||||
) -> bytes:
|
||||
"""Async version of download_checkpoint_archive_from_tinker_path.
|
||||
|
||||
Args:
|
||||
tinker_path: The tinker path to the checkpoint
|
||||
"""
|
||||
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
|
||||
return await self._download_checkpoint_archive_submit(parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id)
|
||||
231
src/tinker/lib/public_interfaces/sampling_client.py
Normal file
231
src/tinker/lib/public_interfaces/sampling_client.py
Normal file
|
|
@ -0,0 +1,231 @@
|
|||
"""SamplingClient for Tinker API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import Future as ConcurrentFuture
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, TypeVar, cast
|
||||
|
||||
import tinker
|
||||
from tinker import types
|
||||
from tinker._types import NOT_GIVEN
|
||||
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
|
||||
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
|
||||
from tinker.lib.telemetry import Telemetry, capture_exceptions
|
||||
from tinker.lib.telemetry_provider import TelemetryProvider
|
||||
|
||||
from ..api_future_impl import QueueState, QueueStateObserver, _APIFuture
|
||||
from ..retry_handler import RetryConfig, RetryHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..internal_client_holder import InternalClientHolder
|
||||
|
||||
# pyright: reportPrivateImportUsage=false
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
U = TypeVar("U")
|
||||
|
||||
|
||||
class SamplingClient(TelemetryProvider, QueueStateObserver):
|
||||
"""Client for text generation and inference from trained or base models.
|
||||
|
||||
The SamplingClient lets you generate text tokens from either a base model or from weights
|
||||
you've saved using a TrainingClient. You typically get one by calling
|
||||
`service_client.create_sampling_client()` or `training_client.save_weights_and_get_sampling_client()`.
|
||||
Key methods:
|
||||
- sample() - generate text completions with customizable parameters
|
||||
- compute_logprobs() - get log probabilities for prompt tokens
|
||||
|
||||
Args:
|
||||
holder: Internal client managing HTTP connections and async operations
|
||||
model_path: Path to saved model weights (starts with 'tinker://')
|
||||
base_model: Name of base model to use for inference
|
||||
retry_config: Configuration for retrying failed requests
|
||||
|
||||
Example:
|
||||
>>> sampling_client = service_client.create_sampling_client(base_model="Qwen/Qwen2.5-7B")
|
||||
>>> prompt = types.ModelInput.from_ints(tokenizer.encode("The weather today is"))
|
||||
>>> params = types.SamplingParams(max_tokens=20, temperature=0.7)
|
||||
>>> future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
|
||||
>>> result = future.result()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
holder: InternalClientHolder,
|
||||
*,
|
||||
model_path: str | None = None,
|
||||
base_model: str | None = None,
|
||||
retry_config: RetryConfig | None = None,
|
||||
):
|
||||
if model_path and not model_path.startswith("tinker://"):
|
||||
raise ValueError("model_path must start with 'tinker://'")
|
||||
|
||||
self.holder = holder
|
||||
self.model_path = model_path
|
||||
self.base_model = base_model
|
||||
|
||||
# Create retry handler with the provided configuration
|
||||
self.retry_handler = _get_retry_handler(
|
||||
model_path or base_model, retry_config=retry_config, telemetry=holder.get_telemetry()
|
||||
)
|
||||
|
||||
self.feature_gates = set(
|
||||
os.environ.get("TINKER_FEATURE_GATES", "async_sampling").split(",")
|
||||
)
|
||||
|
||||
self._last_queue_state_logged: float = 0
|
||||
|
||||
async def _send_asample_request(
|
||||
self,
|
||||
num_samples: int,
|
||||
prompt: types.ModelInput,
|
||||
sampling_params: types.SamplingParams,
|
||||
include_prompt_logprobs: bool,
|
||||
idempotency_key: str,
|
||||
):
|
||||
try:
|
||||
with self.holder.aclient(ClientConnectionPoolType.SAMPLE) as client:
|
||||
return await client.sampling.asample(
|
||||
num_samples=num_samples,
|
||||
prompt=cast(types._ModelInputParam, prompt.model_dump()),
|
||||
sampling_params=cast(types._SamplingParamsParam, sampling_params.model_dump()),
|
||||
model_path=self.model_path if self.model_path is not None else NOT_GIVEN,
|
||||
prompt_logprobs=include_prompt_logprobs,
|
||||
base_model=self.base_model if self.base_model is not None else NOT_GIVEN,
|
||||
max_retries=0,
|
||||
extra_headers={"X-Tinker-Sampling-Backpressure": "1"},
|
||||
idempotency_key=idempotency_key,
|
||||
)
|
||||
except tinker.APIStatusError as e:
|
||||
if e.status_code == 429:
|
||||
return None
|
||||
raise e
|
||||
|
||||
async def _sample_async_impl(
|
||||
self,
|
||||
prompt: types.ModelInput,
|
||||
num_samples: int,
|
||||
sampling_params: types.SamplingParams,
|
||||
include_prompt_logprobs: bool,
|
||||
) -> types.SampleResponse:
|
||||
idempotency_key = self.holder.make_idempotency_key()
|
||||
async with self.holder._sample_dispatch_semaphore:
|
||||
while True:
|
||||
if (
|
||||
self.holder._sample_backoff_until is not None
|
||||
and time.time() < self.holder._sample_backoff_until
|
||||
):
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
untyped_future = await self.holder.execute_with_retries(
|
||||
self._send_asample_request,
|
||||
num_samples,
|
||||
prompt,
|
||||
sampling_params,
|
||||
include_prompt_logprobs,
|
||||
idempotency_key,
|
||||
)
|
||||
if untyped_future is not None:
|
||||
break
|
||||
# Handle backoff
|
||||
self.holder._sample_backoff_until = time.time() + 1
|
||||
continue
|
||||
|
||||
return await _APIFuture(
|
||||
types.SampleResponse,
|
||||
self.holder,
|
||||
untyped_future,
|
||||
request_start_time=time.time(),
|
||||
request_type="Sample",
|
||||
queue_state_observer=self,
|
||||
).result_async()
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
def sample(
|
||||
self,
|
||||
prompt: types.ModelInput,
|
||||
num_samples: int,
|
||||
sampling_params: types.SamplingParams,
|
||||
include_prompt_logprobs: bool = False,
|
||||
) -> ConcurrentFuture[types.SampleResponse]:
|
||||
"""Internal method that does the actual API call without retry logic."""
|
||||
|
||||
async def _sample_async():
|
||||
return await self._sample_async_impl(
|
||||
prompt, num_samples, sampling_params, include_prompt_logprobs
|
||||
)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def _sample_async_with_retries() -> types.SampleResponse:
|
||||
return await self.retry_handler.execute(_sample_async)
|
||||
|
||||
# TODO make max_tokens a required field
|
||||
return self.holder.run_coroutine_threadsafe(_sample_async_with_retries()).future()
|
||||
|
||||
async def sample_async(
|
||||
self,
|
||||
prompt: types.ModelInput,
|
||||
num_samples: int,
|
||||
sampling_params: types.SamplingParams,
|
||||
include_prompt_logprobs: bool = False,
|
||||
) -> types.SampleResponse:
|
||||
return await AwaitableConcurrentFuture(
|
||||
self.sample(prompt, num_samples, sampling_params, include_prompt_logprobs)
|
||||
)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
def compute_logprobs(
|
||||
self, prompt: types.ModelInput
|
||||
) -> ConcurrentFuture[Sequence[float | None]]:
|
||||
async def _compute_logprobs_async() -> Sequence[float | None]:
|
||||
sample_res = await self._sample_async_impl(
|
||||
prompt,
|
||||
num_samples=1,
|
||||
sampling_params=types.SamplingParams(max_tokens=1),
|
||||
include_prompt_logprobs=True,
|
||||
)
|
||||
return cast(list[float | None], sample_res.prompt_logprobs)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def _compute_logprobs_async_with_retries() -> Sequence[float | None]:
|
||||
return await self.retry_handler.execute(_compute_logprobs_async)
|
||||
|
||||
return self.holder.run_coroutine_threadsafe(_compute_logprobs_async_with_retries()).future()
|
||||
|
||||
async def compute_logprobs_async(self, prompt: types.ModelInput) -> Sequence[float | None]:
|
||||
return await AwaitableConcurrentFuture(self.compute_logprobs(prompt))
|
||||
|
||||
def get_telemetry(self) -> Telemetry | None:
|
||||
return self.holder.get_telemetry()
|
||||
|
||||
def on_queue_state_change(self, queue_state: QueueState) -> None:
|
||||
QUEUE_STATE_LOG_INTERVAL = 60
|
||||
if queue_state == QueueState.ACTIVE:
|
||||
return
|
||||
if time.time() - self._last_queue_state_logged < QUEUE_STATE_LOG_INTERVAL:
|
||||
return
|
||||
if queue_state == QueueState.PAUSED_RATE_LIMIT:
|
||||
reason = "concurrent LoRA rate limit hit"
|
||||
elif queue_state == QueueState.PAUSED_CAPACITY:
|
||||
reason = "out of capacity"
|
||||
else:
|
||||
reason = "unknown"
|
||||
self._last_queue_state_logged = time.time()
|
||||
|
||||
logger.warning(f"Sampling is paused for {self.model_path}. Reason: {reason}")
|
||||
|
||||
|
||||
@lru_cache(maxsize=100)
|
||||
def _get_retry_handler(
|
||||
name: str, retry_config: RetryConfig | None = None, telemetry: Telemetry | None = None
|
||||
) -> RetryHandler:
|
||||
retry_config = retry_config or RetryConfig()
|
||||
return RetryHandler(config=retry_config, name=name, telemetry=telemetry)
|
||||
247
src/tinker/lib/public_interfaces/service_client.py
Normal file
247
src/tinker/lib/public_interfaces/service_client.py
Normal file
|
|
@ -0,0 +1,247 @@
|
|||
"""ServiceClient for Tinker API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from tinker import types
|
||||
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
|
||||
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
|
||||
from tinker.lib.telemetry import Telemetry, capture_exceptions
|
||||
from tinker.lib.telemetry_provider import TelemetryProvider
|
||||
|
||||
from ..api_future_impl import _APIFuture
|
||||
from ..internal_client_holder import InternalClientHolder
|
||||
from ..retry_handler import RetryConfig
|
||||
from ..sync_only import sync_only
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .rest_client import RestClient
|
||||
from .sampling_client import SamplingClient
|
||||
from .training_client import TrainingClient
|
||||
|
||||
# pyright: reportPrivateImportUsage=false
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServiceClient(TelemetryProvider):
|
||||
"""The ServiceClient is the main entry point for the Tinker API. It provides methods to:
|
||||
- Query server capabilities and health status
|
||||
- Generate TrainingClient instances for model training workflows
|
||||
- Generate SamplingClient instances for text generation and inference
|
||||
- Generate RestClient instances for REST API operations like listing weights
|
||||
|
||||
Args:
|
||||
**kwargs: advanced options passed to the underlying HTTP client,
|
||||
including API keys, headers, and connection settings.
|
||||
|
||||
Example:
|
||||
>>> client = ServiceClient()
|
||||
# ^^^ near-instant
|
||||
>>> training_client = client.create_lora_training_client(base_model="Qwen/Qwen3-8B")
|
||||
# ^^^ takes a moment as we initialize the model and assign resources
|
||||
>>> sampling_client = client.create_sampling_client(base_model="Qwen/Qwen3-8B")
|
||||
# ^^^ near-instant
|
||||
>>> rest_client = client.create_rest_client()
|
||||
# ^^^ near-instant
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
default_headers = _get_default_headers() | kwargs.pop("default_headers", {})
|
||||
self.holder = InternalClientHolder(
|
||||
**kwargs, default_headers=default_headers, _strict_response_validation=True
|
||||
)
|
||||
|
||||
def _get_server_capabilities_submit(
|
||||
self,
|
||||
) -> AwaitableConcurrentFuture[types.GetServerCapabilitiesResponse]:
|
||||
async def _get_server_capabilities_async():
|
||||
async def _send_request():
|
||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||
return await client.service.get_server_capabilities()
|
||||
|
||||
return await self.holder.execute_with_retries(_send_request)
|
||||
|
||||
return self.holder.run_coroutine_threadsafe(_get_server_capabilities_async())
|
||||
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def get_server_capabilities(self) -> types.GetServerCapabilitiesResponse:
|
||||
return self._get_server_capabilities_submit().result()
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def get_server_capabilities_async(self) -> types.GetServerCapabilitiesResponse:
|
||||
return await self._get_server_capabilities_submit()
|
||||
|
||||
def _create_model_submit(
|
||||
self, base_model: str, lora_config: types.LoraConfig
|
||||
) -> AwaitableConcurrentFuture[types.ModelID]:
|
||||
async def _create_model_async():
|
||||
start_time = time.time()
|
||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||
future = await client.models.create(
|
||||
base_model=base_model, lora_config=_to_lora_config_params(lora_config)
|
||||
)
|
||||
create_model_response = await _APIFuture(
|
||||
types.CreateModelResponse,
|
||||
self.holder,
|
||||
future,
|
||||
request_start_time=start_time,
|
||||
request_type="CreateModel",
|
||||
).result_async()
|
||||
return create_model_response.model_id
|
||||
|
||||
return self.holder.run_coroutine_threadsafe(_create_model_async())
|
||||
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def create_lora_training_client(
|
||||
self,
|
||||
base_model: str,
|
||||
rank: int = 32,
|
||||
seed: int | None = None,
|
||||
train_mlp: bool = True,
|
||||
train_attn: bool = True,
|
||||
train_unembed: bool = True,
|
||||
) -> TrainingClient:
|
||||
assert any([train_mlp, train_attn, train_unembed]), (
|
||||
"At least one of train_mlp, train_attn, or train_unembed must be True"
|
||||
)
|
||||
model_id = self._create_model_submit(
|
||||
base_model,
|
||||
types.LoraConfig(
|
||||
rank=rank,
|
||||
seed=seed,
|
||||
train_mlp=train_mlp,
|
||||
train_attn=train_attn,
|
||||
train_unembed=train_unembed,
|
||||
),
|
||||
).result()
|
||||
logger.info(f"Creating TrainingClient for {model_id=}")
|
||||
return self.create_training_client(model_id)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def create_lora_training_client_async(
|
||||
self,
|
||||
base_model: str,
|
||||
rank: int = 32,
|
||||
seed: int | None = None,
|
||||
train_mlp: bool = True,
|
||||
train_attn: bool = True,
|
||||
train_unembed: bool = True,
|
||||
) -> TrainingClient:
|
||||
assert any([train_mlp, train_attn, train_unembed]), (
|
||||
"At least one of train_mlp, train_attn, or train_unembed must be True"
|
||||
)
|
||||
model_id = await self._create_model_submit(
|
||||
base_model,
|
||||
types.LoraConfig(
|
||||
rank=rank,
|
||||
seed=seed,
|
||||
train_mlp=train_mlp,
|
||||
train_attn=train_attn,
|
||||
train_unembed=train_unembed,
|
||||
),
|
||||
)
|
||||
logger.info(f"Creating TrainingClient for {model_id=}")
|
||||
return self.create_training_client(model_id)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
def create_training_client(self, model_id: types.ModelID | None = None) -> TrainingClient:
|
||||
from .training_client import TrainingClient
|
||||
|
||||
return TrainingClient(self.holder, model_id=model_id)
|
||||
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def create_training_client_from_state(self, path: str) -> TrainingClient:
|
||||
rest_client = self.create_rest_client()
|
||||
training_run = rest_client.get_training_run_by_tinker_path(path).result()
|
||||
|
||||
training_client = self.create_lora_training_client(
|
||||
base_model=training_run.base_model,
|
||||
rank=training_run.lora_rank,
|
||||
)
|
||||
|
||||
training_client.load_state(path).result()
|
||||
return training_client
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def create_training_client_from_state_async(self, path: str) -> TrainingClient:
|
||||
rest_client = self.create_rest_client()
|
||||
training_run = await rest_client.get_training_run_by_tinker_path_async(path)
|
||||
|
||||
# Right now all training runs are LoRa runs.
|
||||
assert training_run.is_lora and training_run.lora_rank is not None
|
||||
|
||||
training_client = await self.create_lora_training_client_async(
|
||||
base_model=training_run.base_model,
|
||||
rank=training_run.lora_rank,
|
||||
)
|
||||
|
||||
load_future = await training_client.load_state_async(path)
|
||||
await load_future.result_async()
|
||||
return training_client
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
def create_sampling_client(
|
||||
self,
|
||||
model_path: str | None = None,
|
||||
base_model: str | None = None,
|
||||
retry_config: RetryConfig | None = None,
|
||||
) -> SamplingClient:
|
||||
from .sampling_client import SamplingClient
|
||||
|
||||
if model_path is None and base_model is None:
|
||||
raise ValueError("Either model_path or base_model must be provided")
|
||||
return SamplingClient(
|
||||
self.holder,
|
||||
model_path=model_path,
|
||||
base_model=base_model,
|
||||
retry_config=retry_config,
|
||||
)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
def create_rest_client(self) -> RestClient:
|
||||
"""Create a RestClient for REST API operations.
|
||||
|
||||
Returns:
|
||||
RestClient: A client for listing weights and other REST operations
|
||||
|
||||
Example:
|
||||
>>> rest_client = service_client.create_rest_client()
|
||||
>>> weights = rest_client.list_model_weights("my-model-id").result()
|
||||
"""
|
||||
from .rest_client import RestClient
|
||||
|
||||
return RestClient(self.holder)
|
||||
|
||||
def get_telemetry(self) -> Telemetry | None:
|
||||
return self.holder.get_telemetry()
|
||||
|
||||
|
||||
def _get_default_headers() -> dict[str, str]:
|
||||
headers = {}
|
||||
|
||||
if (api_key := os.environ.get("TINKER_API_KEY", "")) and "X-API-Key" not in headers:
|
||||
headers["X-API-Key"] = api_key
|
||||
|
||||
headers["X-Username"] = os.environ.get("USER", "")
|
||||
|
||||
if (
|
||||
client_id := os.environ.get("CLOUDFLARE_ACCESS_CLIENT_ID")
|
||||
) and "CF-Access-Client-Id" not in headers:
|
||||
headers["CF-Access-Client-Id"] = client_id
|
||||
if (
|
||||
client_secret := os.environ.get("CLOUDFLARE_ACCESS_CLIENT_SECRET")
|
||||
) and "CF-Access-Client-Secret" not in headers:
|
||||
headers["CF-Access-Client-Secret"] = client_secret
|
||||
return headers
|
||||
|
||||
|
||||
def _to_lora_config_params(x: types.LoraConfig) -> types._LoraConfigParam:
|
||||
return cast(types._LoraConfigParam, x.model_dump())
|
||||
585
src/tinker/lib/public_interfaces/training_client.py
Normal file
585
src/tinker/lib/public_interfaces/training_client.py
Normal file
|
|
@ -0,0 +1,585 @@
|
|||
"""TrainingClient for Tinker API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import cache
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Tuple, cast
|
||||
|
||||
import torch
|
||||
|
||||
from tinker import types
|
||||
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
|
||||
from tinker.lib.public_interfaces.api_future import APIFuture, AwaitableConcurrentFuture
|
||||
from tinker.lib.telemetry import Telemetry, capture_exceptions
|
||||
from tinker.lib.telemetry_provider import TelemetryProvider
|
||||
from tinker.types import training_optim_step_params
|
||||
|
||||
from ..api_future_impl import (
|
||||
QueueState,
|
||||
QueueStateObserver,
|
||||
_APIFuture,
|
||||
_CombinedAPIFuture,
|
||||
)
|
||||
from ..chunked_fwdbwd_helpers import combine_fwd_bwd_output_results
|
||||
from ..retry_handler import RetryConfig
|
||||
from ..sync_only import sync_only
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ..internal_client_holder import InternalClientHolder
|
||||
from .sampling_client import SamplingClient
|
||||
|
||||
# pyright: reportPrivateImportUsage=false
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# FwdBwdChunkSize
|
||||
MAX_CHUNK_LEN = 128
|
||||
MAX_CHUNK_NUMBER_COUNT = 500000
|
||||
MODEL_ID_NOT_SET_ERROR = "model_id must be set before calling forward. Try initializing the TrainingClient with a model_id by either calling create_lora_training_client on the ServiceClient, or initiliazing the TrainingClient with an existing model_id."
|
||||
|
||||
CustomLossFnV1 = Callable[[List[types.Datum], List[Any]], Tuple[Any, Dict[str, float]]]
|
||||
|
||||
|
||||
class TrainingClient(TelemetryProvider, QueueStateObserver):
|
||||
"""Client for training ML models with forward/backward passes and optimization.
|
||||
|
||||
The TrainingClient corresponds to a fine-tuned model that you can train and sample from.
|
||||
You typically get one by calling `service_client.create_lora_training_client()`.
|
||||
Key methods:
|
||||
- forward_backward() - compute gradients for training
|
||||
- optim_step() - update model parameters with Adam optimizer
|
||||
- save_weights_and_get_sampling_client() - export trained model for inference
|
||||
|
||||
Args:
|
||||
holder: Internal client managing HTTP connections and async operations
|
||||
model_id: Unique identifier for the model to train. Required for training operations.
|
||||
|
||||
Example:
|
||||
>>> training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen2.5-7B")
|
||||
>>> fwdbwd_future = training_client.forward_backward(training_data, "cross_entropy")
|
||||
>>> optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
|
||||
>>> fwdbwd_result = fwdbwd_future.result() # Wait for gradients
|
||||
>>> optim_result = optim_future.result() # Wait for parameter update
|
||||
>>> sampling_client = training_client.save_weights_and_get_sampling_client("my-model")
|
||||
"""
|
||||
|
||||
def __init__(self, holder: InternalClientHolder, model_id: types.ModelID | None = None):
|
||||
self.holder = holder
|
||||
self.model_id = model_id
|
||||
|
||||
self._training_client_id: int = self.holder.get_training_client_id()
|
||||
|
||||
self._request_id_lock: threading.Lock = threading.Lock()
|
||||
self._request_id_counter: int = 0
|
||||
|
||||
self._turn_counter: int = 0
|
||||
self._turn_waiters: dict[int, asyncio.Event] = {}
|
||||
|
||||
self._last_queue_state_logged: float = 0
|
||||
|
||||
# Reserves a request id for a request. Requests are to be executed in the order of request ids.
|
||||
def _get_request_id(self) -> int:
|
||||
with self._request_id_lock:
|
||||
request_id = self._request_id_counter
|
||||
self._request_id_counter += 1
|
||||
return request_id
|
||||
|
||||
# Waits for the turn for a given request id to be executed.
|
||||
# This has to be used via a with statement so that the turn is released
|
||||
# only after current request was successfully dispatched.
|
||||
@asynccontextmanager
|
||||
async def _take_turn(self, request_id: int):
|
||||
assert self._turn_counter <= request_id, "Same request id cannot be taken twice"
|
||||
|
||||
if self._turn_counter < request_id:
|
||||
try:
|
||||
event = asyncio.Event()
|
||||
self._turn_waiters[request_id] = event
|
||||
await event.wait()
|
||||
finally:
|
||||
del self._turn_waiters[request_id]
|
||||
|
||||
assert self._turn_counter == request_id
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._turn_counter += 1
|
||||
if self._turn_counter in self._turn_waiters:
|
||||
self._turn_waiters[self._turn_counter].set()
|
||||
|
||||
def _make_idempotency_key(self, request_id: int) -> str:
|
||||
return self.holder.make_training_client_idempotency_key(
|
||||
self._training_client_id, request_id
|
||||
)
|
||||
|
||||
def _guaranteed_model_id(self) -> types.ModelID:
|
||||
assert self.model_id is not None, MODEL_ID_NOT_SET_ERROR
|
||||
return self.model_id
|
||||
|
||||
def _estimate_number_count(self, datum: types.Datum) -> int:
|
||||
return datum.model_input.length + sum(
|
||||
len(value.data) for _, value in datum.loss_fn_inputs.items()
|
||||
)
|
||||
|
||||
def _chunked_requests_generator(
|
||||
self, data: List[types.Datum]
|
||||
) -> Generator[List[types.Datum], None, None]:
|
||||
current_chunk: List[types.Datum] = []
|
||||
current_chunk_number_count = 0
|
||||
|
||||
for datum in data:
|
||||
estimated_number_count = self._estimate_number_count(datum)
|
||||
if (
|
||||
len(current_chunk) > 0
|
||||
and current_chunk_number_count + estimated_number_count > MAX_CHUNK_NUMBER_COUNT
|
||||
) or (len(current_chunk) == MAX_CHUNK_LEN):
|
||||
yield current_chunk
|
||||
current_chunk = []
|
||||
current_chunk_number_count = 0
|
||||
|
||||
current_chunk.append(datum)
|
||||
current_chunk_number_count += estimated_number_count
|
||||
|
||||
if len(current_chunk) > 0:
|
||||
yield current_chunk
|
||||
|
||||
def _chunked_requests(self, data: List[types.Datum]) -> List[tuple[int, List[types.Datum]]]:
|
||||
return [(self._get_request_id(), chunk) for chunk in self._chunked_requests_generator(data)]
|
||||
|
||||
async def _send_single_forward_request(
|
||||
self, request_id: int, data: List[types.Datum], loss_fn: types.LossFnType
|
||||
):
|
||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||
return await client.training.forward(
|
||||
model_id=self._guaranteed_model_id(),
|
||||
forward_input=_to_fwdbwd_input_params(
|
||||
types.ForwardBackwardInput(data=data, loss_fn=loss_fn)
|
||||
),
|
||||
idempotency_key=self._make_idempotency_key(request_id),
|
||||
)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
def forward(
|
||||
self, data: List[types.Datum], loss_fn: types.LossFnType
|
||||
) -> APIFuture[types.ForwardBackwardOutput]:
|
||||
requests = self._chunked_requests(data)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def _forward_async():
|
||||
start_time = time.time()
|
||||
futures = []
|
||||
for request_id, data in requests:
|
||||
async with self._take_turn(request_id):
|
||||
untyped_future = await self.holder.execute_with_retries(
|
||||
self._send_single_forward_request, request_id, data, loss_fn
|
||||
)
|
||||
api_future = _APIFuture(
|
||||
types.ForwardBackwardOutput,
|
||||
self.holder,
|
||||
untyped_future,
|
||||
request_start_time=start_time,
|
||||
request_type="Forward",
|
||||
queue_state_observer=self,
|
||||
)
|
||||
futures.append(api_future)
|
||||
return await _CombinedAPIFuture(futures, combine_fwd_bwd_output_results, self.holder)
|
||||
|
||||
return self.holder.run_coroutine_threadsafe(_forward_async())
|
||||
|
||||
async def forward_async(
|
||||
self, data: List[types.Datum], loss_fn: types.LossFnType
|
||||
) -> APIFuture[types.ForwardBackwardOutput]:
|
||||
return self.forward(data, loss_fn)
|
||||
|
||||
async def _send_single_forward_backward_request(
|
||||
self, request_id: int, data: List[types.Datum], loss_fn: types.LossFnType
|
||||
):
|
||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||
return await client.training.forward_backward(
|
||||
model_id=self._guaranteed_model_id(),
|
||||
forward_backward_input=_to_fwdbwd_input_params(
|
||||
types.ForwardBackwardInput(data=data, loss_fn=loss_fn)
|
||||
),
|
||||
idempotency_key=self._make_idempotency_key(request_id),
|
||||
)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
def forward_backward(
|
||||
self, data: List[types.Datum], loss_fn: types.LossFnType
|
||||
) -> APIFuture[types.ForwardBackwardOutput]:
|
||||
requests = self._chunked_requests(data)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def _forward_backward_async():
|
||||
futures = []
|
||||
start_time = time.time()
|
||||
|
||||
for request_id, data in requests:
|
||||
async with self._take_turn(request_id):
|
||||
untyped_future = await self.holder.execute_with_retries(
|
||||
self._send_single_forward_backward_request, request_id, data, loss_fn
|
||||
)
|
||||
api_future = _APIFuture(
|
||||
types.ForwardBackwardOutput,
|
||||
self.holder,
|
||||
untyped_future,
|
||||
request_start_time=start_time,
|
||||
request_type="ForwardBackward",
|
||||
queue_state_observer=self,
|
||||
)
|
||||
futures.append(api_future)
|
||||
|
||||
return await _CombinedAPIFuture(futures, combine_fwd_bwd_output_results, self.holder)
|
||||
|
||||
return self.holder.run_coroutine_threadsafe(_forward_backward_async())
|
||||
|
||||
async def forward_backward_async(
|
||||
self, data: List[types.Datum], loss_fn: types.LossFnType
|
||||
) -> APIFuture[types.ForwardBackwardOutput]:
|
||||
return self.forward_backward(data, loss_fn)
|
||||
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def forward_backward_custom(
|
||||
self, data: List[types.Datum], loss_fn: CustomLossFnV1
|
||||
) -> APIFuture[types.ForwardBackwardOutput]:
|
||||
"""Synchronous version of forward_backward_custom_async."""
|
||||
return self.holder.run_coroutine_threadsafe(
|
||||
self.forward_backward_custom_async(data, loss_fn)
|
||||
).result()
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def forward_backward_custom_async(
|
||||
self, data: List[types.Datum], loss_fn: CustomLossFnV1
|
||||
) -> APIFuture[types.ForwardBackwardOutput]:
|
||||
import torch
|
||||
|
||||
# First do a forward pass and get logprobs
|
||||
forward_future = await self.forward_async(data, "cross_entropy")
|
||||
forward_result = await forward_future.result_async()
|
||||
logprobs_list: List[torch.Tensor] = []
|
||||
for out in forward_result.loss_fn_outputs:
|
||||
logprob = torch.tensor(out["logprobs"].data).clone().detach().requires_grad_(True)
|
||||
logprobs_list.append(logprob)
|
||||
|
||||
# Now apply user-provided function
|
||||
loss, metrics = loss_fn(data, logprobs_list)
|
||||
loss.backward()
|
||||
grads = []
|
||||
for logprob in logprobs_list:
|
||||
if logprob.grad is None:
|
||||
raise ValueError("No gradient computed for logprob tensor")
|
||||
grads.append(logprob.grad)
|
||||
|
||||
linear_loss_data = []
|
||||
for datum, grad in zip(data, grads):
|
||||
loss_fn_inputs: Any = {
|
||||
"target_tokens": datum.loss_fn_inputs["target_tokens"],
|
||||
"weights": -grad, # Pass PyTorch tensor directly (will be converted to TensorData)
|
||||
}
|
||||
linear_loss_data.append(
|
||||
types.Datum(
|
||||
model_input=datum.model_input,
|
||||
loss_fn_inputs=loss_fn_inputs,
|
||||
)
|
||||
)
|
||||
|
||||
# Do the backward pass with the gradients
|
||||
backward_future = await self.forward_backward_async(linear_loss_data, "cross_entropy")
|
||||
|
||||
# We need to slightly modify the future to add the custom metrics, so we use _CombinedAPIFuture
|
||||
# to transform the future.
|
||||
def add_custom_metrics(
|
||||
results: List[types.ForwardBackwardOutput],
|
||||
) -> types.ForwardBackwardOutput:
|
||||
result = results[0] # Single result
|
||||
result.metrics.update(metrics)
|
||||
return result
|
||||
|
||||
return _CombinedAPIFuture([backward_future], add_custom_metrics, self.holder)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
def optim_step(self, adam_params: types.AdamParams) -> APIFuture[types.OptimStepResponse]:
|
||||
request_id = self._get_request_id()
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def _optim_step_async():
|
||||
start_time = time.time()
|
||||
|
||||
async def _send_request():
|
||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||
return await client.training.optim_step(
|
||||
model_id=self._guaranteed_model_id(),
|
||||
adam_params=_to_adam_params(adam_params),
|
||||
idempotency_key=self._make_idempotency_key(request_id),
|
||||
)
|
||||
|
||||
async with self._take_turn(request_id):
|
||||
untyped_future = await self.holder.execute_with_retries(_send_request)
|
||||
return await _APIFuture(
|
||||
types.OptimStepResponse,
|
||||
self.holder,
|
||||
untyped_future,
|
||||
request_start_time=start_time,
|
||||
request_type="OptimStep",
|
||||
queue_state_observer=self,
|
||||
)
|
||||
|
||||
return self.holder.run_coroutine_threadsafe(_optim_step_async())
|
||||
|
||||
async def optim_step_async(
|
||||
self, adam_params: types.AdamParams
|
||||
) -> APIFuture[types.OptimStepResponse]:
|
||||
return self.optim_step(adam_params)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
def save_state(self, name: str) -> APIFuture[types.SaveWeightsResponse]:
|
||||
request_id = self._get_request_id()
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def _save_state_async():
|
||||
start_time = time.time()
|
||||
|
||||
async def _send_request():
|
||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||
return await client.weights.save(
|
||||
model_id=self._guaranteed_model_id(),
|
||||
path=name,
|
||||
idempotency_key=self._make_idempotency_key(request_id),
|
||||
)
|
||||
|
||||
async with self._take_turn(request_id):
|
||||
future = await self.holder.execute_with_retries(_send_request)
|
||||
return await _APIFuture(
|
||||
types.SaveWeightsResponse,
|
||||
self.holder,
|
||||
future,
|
||||
request_start_time=start_time,
|
||||
request_type="SaveWeights",
|
||||
queue_state_observer=self,
|
||||
)
|
||||
|
||||
return self.holder.run_coroutine_threadsafe(_save_state_async())
|
||||
|
||||
async def save_state_async(self, name: str) -> APIFuture[types.SaveWeightsResponse]:
|
||||
return self.save_state(name)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
def load_state(self, path: str) -> APIFuture[types.LoadWeightsResponse]:
|
||||
request_id = self._get_request_id()
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def _load_state_async():
|
||||
start_time = time.time()
|
||||
|
||||
async def _send_request():
|
||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||
return await client.weights.load(
|
||||
model_id=self._guaranteed_model_id(),
|
||||
path=path,
|
||||
idempotency_key=self._make_idempotency_key(request_id),
|
||||
)
|
||||
|
||||
async with self._take_turn(request_id):
|
||||
future = await self.holder.execute_with_retries(_send_request)
|
||||
return await _APIFuture(
|
||||
types.LoadWeightsResponse,
|
||||
self.holder,
|
||||
future,
|
||||
request_start_time=start_time,
|
||||
request_type="LoadWeights",
|
||||
queue_state_observer=self,
|
||||
)
|
||||
|
||||
return self.holder.run_coroutine_threadsafe(_load_state_async())
|
||||
|
||||
async def load_state_async(self, path: str) -> APIFuture[types.LoadWeightsResponse]:
|
||||
return self.load_state(path)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
def save_weights_for_sampler(self, name: str) -> APIFuture[types.SaveWeightsForSamplerResponse]:
|
||||
request_id = self._get_request_id()
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def _save_weights_for_sampler_async():
|
||||
start_time = time.time()
|
||||
|
||||
async def _send_request():
|
||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||
return await client.weights.save_for_sampler(
|
||||
model_id=self._guaranteed_model_id(),
|
||||
path=name,
|
||||
idempotency_key=self._make_idempotency_key(request_id),
|
||||
)
|
||||
|
||||
async with self._take_turn(request_id):
|
||||
future = await self.holder.execute_with_retries(_send_request)
|
||||
return await _APIFuture(
|
||||
types.SaveWeightsForSamplerResponse,
|
||||
self.holder,
|
||||
future,
|
||||
request_start_time=start_time,
|
||||
request_type="SaveWeightsForSampler",
|
||||
queue_state_observer=self,
|
||||
)
|
||||
|
||||
return self.holder.run_coroutine_threadsafe(_save_weights_for_sampler_async())
|
||||
|
||||
async def save_weights_for_sampler_async(
|
||||
self, name: str
|
||||
) -> APIFuture[types.SaveWeightsForSamplerResponse]:
|
||||
return self.save_weights_for_sampler(name)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
def unload_model(
|
||||
self,
|
||||
) -> APIFuture[types.UnloadModelResponse]:
|
||||
request_id = self._get_request_id()
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def _unload_model_async():
|
||||
start_time = time.time()
|
||||
|
||||
async def _send_request():
|
||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||
return await client.models.unload(
|
||||
model_id=self._guaranteed_model_id(),
|
||||
idempotency_key=self._make_idempotency_key(request_id),
|
||||
)
|
||||
|
||||
async with self._take_turn(request_id):
|
||||
future = await self.holder.execute_with_retries(_send_request)
|
||||
return await _APIFuture(
|
||||
types.UnloadModelResponse,
|
||||
self.holder,
|
||||
future,
|
||||
request_start_time=start_time,
|
||||
request_type="UnloadModel",
|
||||
queue_state_observer=self,
|
||||
)
|
||||
|
||||
return self.holder.run_coroutine_threadsafe(_unload_model_async())
|
||||
|
||||
async def unload_model_async(self) -> APIFuture[types.UnloadModelResponse]:
|
||||
return self.unload_model()
|
||||
|
||||
def _get_info_submit(self) -> AwaitableConcurrentFuture[types.GetInfoResponse]:
|
||||
request_id = self._get_request_id()
|
||||
|
||||
async def _get_info_async():
|
||||
async def _send_request():
|
||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||
return await client.models.get_info(
|
||||
model_id=self._guaranteed_model_id(),
|
||||
idempotency_key=self._make_idempotency_key(request_id),
|
||||
)
|
||||
|
||||
async with self._take_turn(request_id):
|
||||
return await self.holder.execute_with_retries(_send_request)
|
||||
|
||||
return self.holder.run_coroutine_threadsafe(_get_info_async())
|
||||
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def get_info(self) -> types.GetInfoResponse:
|
||||
return self._get_info_submit().result()
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def get_info_async(self) -> types.GetInfoResponse:
|
||||
return await self._get_info_submit()
|
||||
|
||||
@cache
|
||||
@capture_exceptions(fatal=True)
|
||||
def get_tokenizer(self) -> PreTrainedTokenizer:
|
||||
return _get_tokenizer(self._guaranteed_model_id(), self.holder)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
def create_sampling_client(
|
||||
self, model_path: str, retry_config: RetryConfig | None = None
|
||||
) -> SamplingClient:
|
||||
from .sampling_client import SamplingClient
|
||||
|
||||
return SamplingClient(self.holder, model_path=model_path, retry_config=retry_config)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
def save_weights_and_get_sampling_client(
|
||||
self, name: str, retry_config: RetryConfig | None = None
|
||||
) -> SamplingClient:
|
||||
from .sampling_client import SamplingClient
|
||||
|
||||
path = self.save_weights_for_sampler(name).result().path
|
||||
return SamplingClient(self.holder, model_path=path, retry_config=retry_config)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def save_weights_and_get_sampling_client_async(
|
||||
self, name: str, retry_config: RetryConfig | None = None
|
||||
) -> SamplingClient:
|
||||
from .sampling_client import SamplingClient
|
||||
|
||||
save_weights_future = await self.save_weights_for_sampler_async(name)
|
||||
save_weights_result = await save_weights_future.result_async()
|
||||
model_path = save_weights_result.path
|
||||
return SamplingClient(self.holder, model_path=model_path, retry_config=retry_config)
|
||||
|
||||
def get_telemetry(self) -> Telemetry | None:
|
||||
return self.holder.get_telemetry()
|
||||
|
||||
def on_queue_state_change(self, queue_state: QueueState) -> None:
|
||||
QUEUE_STATE_LOG_INTERVAL = 60
|
||||
if queue_state == QueueState.ACTIVE:
|
||||
return
|
||||
if time.time() - self._last_queue_state_logged < QUEUE_STATE_LOG_INTERVAL:
|
||||
return
|
||||
self._last_queue_state_logged = time.time()
|
||||
|
||||
if queue_state == QueueState.PAUSED_RATE_LIMIT:
|
||||
reason = "concurrent models rate limit hit"
|
||||
elif queue_state == QueueState.PAUSED_CAPACITY:
|
||||
reason = "out of capacity"
|
||||
else:
|
||||
reason = "unknown"
|
||||
logger.warning(f"Training is paused for {self.model_id}. Reason: {reason}")
|
||||
|
||||
|
||||
def _to_fwdbwd_input_params(x: types.ForwardBackwardInput) -> types._ForwardBackwardInputParam:
|
||||
return cast(types._ForwardBackwardInputParam, x.model_dump())
|
||||
|
||||
|
||||
def _to_adam_params(x: types.AdamParams) -> training_optim_step_params.AdamParams:
|
||||
return cast(training_optim_step_params.AdamParams, x.model_dump())
|
||||
|
||||
|
||||
def _get_tokenizer(model_id: types.ModelID, holder: InternalClientHolder) -> PreTrainedTokenizer:
|
||||
# call get_info on model_id
|
||||
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
||||
|
||||
async def _get_info_async():
|
||||
with holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||
return await client.models.get_info(model_id=model_id)
|
||||
|
||||
info = holder.run_coroutine_threadsafe(_get_info_async()).result()
|
||||
model_name = info.model_data.model_name
|
||||
assert model_name is not None, "This shouldn't happen: model_name is None"
|
||||
|
||||
# We generally adhere to the huggingface convention of "<org>/<model>" but
|
||||
# in some cases we'll deploy variants using the format
|
||||
# "<org>/<model>/<variant>". In that case, we want to load the tokenizer
|
||||
# using the huggingface convention.
|
||||
if model_name.startswith("meta-llama/Llama-3"):
|
||||
# Avoid gating of Llama 3 models:
|
||||
tokenizer_id = "baseten/Meta-Llama-3-tokenizer"
|
||||
elif model_name.count("/") == 2:
|
||||
org, model, _variant = model_name.split("/", 2)
|
||||
tokenizer_id = f"{org}/{model}"
|
||||
else:
|
||||
tokenizer_id = model_name
|
||||
|
||||
return AutoTokenizer.from_pretrained(tokenizer_id, fast=True)
|
||||
278
src/tinker/lib/retry_handler.py
Normal file
278
src/tinker/lib/retry_handler.py
Normal file
|
|
@ -0,0 +1,278 @@
|
|||
"""
|
||||
Generalizable retry handler for API requests with connection limiting,
|
||||
progress tracking, and exponential backoff.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Generic, Type, TypeVar
|
||||
|
||||
import httpx
|
||||
|
||||
import tinker
|
||||
from tinker.lib.telemetry import Telemetry
|
||||
|
||||
from .._constants import (
|
||||
DEFAULT_CONNECTION_LIMITS,
|
||||
INITIAL_RETRY_DELAY,
|
||||
MAX_RETRY_DELAY,
|
||||
)
|
||||
from .retryable_exception import RetryableException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def is_retryable_status_code(status_code: int) -> bool:
|
||||
return status_code in (408, 409, 429) or (500 <= status_code < 600)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetryConfig:
|
||||
max_connections: int = DEFAULT_CONNECTION_LIMITS.max_connections or 100
|
||||
progress_timeout: float = 30 * 60 # Very long straggler
|
||||
retry_delay_base: float = INITIAL_RETRY_DELAY
|
||||
retry_delay_max: float = MAX_RETRY_DELAY
|
||||
jitter_factor: float = 0.25
|
||||
enable_retry_logic: bool = True
|
||||
retryable_exceptions: tuple[Type[Exception], ...] = (
|
||||
asyncio.TimeoutError,
|
||||
tinker.APIConnectionError,
|
||||
httpx.TimeoutException,
|
||||
RetryableException,
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.max_connections <= 0:
|
||||
raise ValueError(f"max_connections must be positive, got {self.max_connections}")
|
||||
|
||||
def __hash__(self):
|
||||
return hash(
|
||||
(
|
||||
self.max_connections,
|
||||
self.progress_timeout,
|
||||
self.retry_delay_base,
|
||||
self.retry_delay_max,
|
||||
self.jitter_factor,
|
||||
self.enable_retry_logic,
|
||||
self.retryable_exceptions,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class RetryHandler(Generic[T]): # noqa: UP046
|
||||
"""
|
||||
A generalizable retry handler for API requests.
|
||||
|
||||
Features:
|
||||
- Connection limiting with semaphores
|
||||
- Global progress timeout tracking
|
||||
- Exponential backoff with jitter
|
||||
- Configurable error classification
|
||||
|
||||
Usage:
|
||||
handler = RetryHandler(config=retry_config)
|
||||
result = await handler.execute(my_function, *args, **kwargs)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: RetryConfig = RetryConfig(),
|
||||
name: str = "default",
|
||||
telemetry: Telemetry | None = None,
|
||||
):
|
||||
self.config = config
|
||||
self.name = name
|
||||
self._telemetry = telemetry
|
||||
current_time = time.time()
|
||||
self._last_global_progress = current_time
|
||||
self._last_printed_progress = current_time
|
||||
self._processed_count = 0
|
||||
self._waiting_at_semaphore_count = 0
|
||||
self._in_retry_loop_count = 0
|
||||
self._retry_count = 0
|
||||
self._exception_counts = {} # Track exception types and their counts
|
||||
|
||||
self._errors_since_last_retry: defaultdict[str, int] = defaultdict(int)
|
||||
|
||||
# The semaphore is used to limit the number of concurrent requests.
|
||||
# Without a semaphore, progress can grind to a halt as requests fight
|
||||
# for limited httpx connections.
|
||||
self._semaphore = asyncio.Semaphore(config.max_connections)
|
||||
|
||||
async def execute(self, func: Callable[..., Awaitable[T]], *args: Any, **kwargs: Any) -> T:
|
||||
"""Use as a direct function call."""
|
||||
|
||||
self._waiting_at_semaphore_count += 1
|
||||
async with self._semaphore:
|
||||
self._waiting_at_semaphore_count -= 1
|
||||
if self._in_retry_loop_count == 0:
|
||||
self._last_global_progress = time.time()
|
||||
self._in_retry_loop_count += 1
|
||||
self._maybe_log_progress()
|
||||
|
||||
async def _check_progress(parent_task: asyncio.Task[T]):
|
||||
while True:
|
||||
deadline = self._last_global_progress + self.config.progress_timeout
|
||||
if time.time() > deadline:
|
||||
parent_task._no_progress_made_marker = True # pyright: ignore[reportAttributeAccessIssue]
|
||||
parent_task.cancel()
|
||||
await asyncio.sleep(deadline - time.time())
|
||||
|
||||
current_task = asyncio.current_task()
|
||||
assert current_task is not None
|
||||
current_task._no_progress_made_marker = False # pyright: ignore[reportAttributeAccessIssue]
|
||||
progress_task = asyncio.create_task(_check_progress(current_task))
|
||||
|
||||
try:
|
||||
result = await self._execute_with_retry(func, *args, **kwargs)
|
||||
self._last_global_progress = time.time()
|
||||
return result
|
||||
except asyncio.CancelledError:
|
||||
if current_task._no_progress_made_marker: # pyright: ignore[reportAttributeAccessIssue]
|
||||
current_task.uncancel()
|
||||
# Create a dummy request for the exception (required by APIConnectionError)
|
||||
dummy_request = httpx.Request("GET", "http://localhost")
|
||||
raise tinker.APIConnectionError(
|
||||
message=f"No progress made in {self.config.progress_timeout}s. Requests appear to be stuck.",
|
||||
request=dummy_request,
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
self._in_retry_loop_count -= 1
|
||||
self._maybe_log_progress()
|
||||
progress_task.cancel()
|
||||
|
||||
def _maybe_log_progress(self):
|
||||
current_time = time.time()
|
||||
elapsed_since_last_printed_progress = current_time - self._last_printed_progress
|
||||
finished = self._waiting_at_semaphore_count + self._in_retry_loop_count == 0
|
||||
if elapsed_since_last_printed_progress > 2 or finished:
|
||||
logger.debug(
|
||||
f"[{self.name}]: {self._waiting_at_semaphore_count} waiting, {self._in_retry_loop_count} in progress, {self._processed_count} completed"
|
||||
)
|
||||
if self._errors_since_last_retry:
|
||||
sorted_items = sorted(
|
||||
self._errors_since_last_retry.items(), key=lambda x: x[1], reverse=True
|
||||
)
|
||||
logger.debug(
|
||||
f"[{self.name}]: {self._retry_count} total retries, errors since last log: {sorted_items}"
|
||||
)
|
||||
self._last_printed_progress = current_time
|
||||
self._errors_since_last_retry.clear()
|
||||
|
||||
async def _execute_with_retry(
|
||||
self, func: Callable[..., Awaitable[T]], *args: Any, **kwargs: Any
|
||||
) -> T:
|
||||
"""Main retry logic."""
|
||||
# Fast path: skip all retry logic if disabled
|
||||
if not self.config.enable_retry_logic:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
start_time = time.time()
|
||||
attempt_count = 0
|
||||
while True:
|
||||
current_time = time.time()
|
||||
self._maybe_log_progress()
|
||||
try:
|
||||
attempt_count += 1
|
||||
logger.debug(f"Attempting request (attempt #{attempt_count})")
|
||||
result = await func(*args, **kwargs)
|
||||
|
||||
except Exception as e:
|
||||
exception_str = f"{type(e).__name__}: {str(e) or 'No error message'}"
|
||||
self._errors_since_last_retry[exception_str] += 1
|
||||
should_retry = self._should_retry(e)
|
||||
|
||||
if telemetry := self.get_telemetry():
|
||||
current_time = time.time()
|
||||
telemetry.log(
|
||||
"RetryHandler.execute.exception",
|
||||
event_data={
|
||||
"func": getattr(
|
||||
func, "__qualname__", getattr(func, "__name__", type(func).__name__)
|
||||
),
|
||||
"exception": str(e),
|
||||
"exception_type": type(e).__name__,
|
||||
"exception_stack": "".join(
|
||||
traceback.format_exception(type(e), e, e.__traceback__)
|
||||
)
|
||||
if e.__traceback__
|
||||
else None,
|
||||
"status_code": getattr(e, "status_code", None),
|
||||
"should_retry": should_retry,
|
||||
"attempt_count": attempt_count,
|
||||
"start_time": start_time,
|
||||
"current_time": current_time,
|
||||
"elapsed_time": current_time - start_time,
|
||||
},
|
||||
severity="WARNING" if should_retry else "ERROR",
|
||||
)
|
||||
|
||||
if not should_retry:
|
||||
logger.error(f"Request failed with non-retryable error: {exception_str}")
|
||||
raise
|
||||
|
||||
self._log_retry_reason(e, attempt_count)
|
||||
self._retry_count += 1
|
||||
|
||||
# Calculate retry delay with exponential backoff and jitter
|
||||
retry_delay = self._calculate_retry_delay(attempt_count - 1)
|
||||
logger.debug(f"Retrying in {retry_delay:.2f}s")
|
||||
await asyncio.sleep(retry_delay)
|
||||
else:
|
||||
logger.debug(f"Request succeeded after {attempt_count} attempts")
|
||||
self._processed_count += 1
|
||||
return result
|
||||
|
||||
def _should_retry(self, exception: Exception) -> bool:
|
||||
"""Determine if an exception should trigger a retry."""
|
||||
# Check if it's a generally retryable exception type
|
||||
if isinstance(exception, self.config.retryable_exceptions):
|
||||
return True
|
||||
|
||||
# Check for API status errors with retryable status codes
|
||||
if isinstance(exception, tinker.APIStatusError):
|
||||
return is_retryable_status_code(exception.status_code)
|
||||
|
||||
return False
|
||||
|
||||
def _log_retry_reason(self, exception: Exception, attempt_count: int):
|
||||
"""Log the reason for retrying."""
|
||||
if isinstance(exception, asyncio.TimeoutError):
|
||||
logger.debug("Request timed out")
|
||||
elif isinstance(exception, tinker.APIConnectionError):
|
||||
logger.debug(f"Request failed with connection error: {exception}")
|
||||
elif isinstance(exception, tinker.APIStatusError):
|
||||
logger.debug(
|
||||
f"Request attempt #{attempt_count} failed with status {exception.status_code}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Request attempt #{attempt_count} failed with error: {exception}")
|
||||
|
||||
def _calculate_retry_delay(self, attempt: int) -> float:
|
||||
"""Calculate retry delay with exponential backoff and jitter."""
|
||||
delay = self.config.retry_delay_max
|
||||
try:
|
||||
delay = min(self.config.retry_delay_base * (2**attempt), self.config.retry_delay_max)
|
||||
except OverflowError:
|
||||
# There are two possible overflow errors:
|
||||
# (1) `min` tries to convert the value to a float, which can overflow
|
||||
# if the integer value gets too large
|
||||
# (2) If the attempt number is too large, the `2 ** attempt` will overflow
|
||||
delay = self.config.retry_delay_max
|
||||
|
||||
jitter = delay * self.config.jitter_factor * (2 * random.random() - 1)
|
||||
# Ensure the final delay doesn't exceed the maximum, even with jitter
|
||||
return max(0, min(delay + jitter, self.config.retry_delay_max))
|
||||
|
||||
def get_telemetry(self) -> Telemetry | None:
|
||||
return self._telemetry
|
||||
3
src/tinker/lib/retryable_exception.py
Normal file
3
src/tinker/lib/retryable_exception.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
class RetryableException(Exception):
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
79
src/tinker/lib/sync_only.py
Normal file
79
src/tinker/lib/sync_only.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
"""
|
||||
Decorator to prevent synchronous methods from being called in async contexts.
|
||||
|
||||
This helps users avoid a common footgun where calling sync methods from async code
|
||||
can cause deadlocks and performance issues.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import traceback
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, TypeVar
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def is_jupyter() -> bool:
|
||||
"""Check if code is running in a Jupyter notebook."""
|
||||
try:
|
||||
get_ipython # type: ignore
|
||||
except NameError:
|
||||
return False
|
||||
shell = get_ipython().__class__.__name__ # type: ignore
|
||||
if shell in ("ZMQInteractiveShell", "Shell"):
|
||||
return True # Jupyter notebook or qtconsole
|
||||
return False # Other type of shell
|
||||
|
||||
|
||||
def is_in_async_context() -> bool:
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def make_error_message( # noqa: UP047
|
||||
func: Callable[..., T], args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> str:
|
||||
# If we get here, we're in an async context - this is bad!
|
||||
method_name = func.__name__
|
||||
async_method_name = f"{method_name}_async"
|
||||
|
||||
# Get the class name for a better error message
|
||||
class_name = ""
|
||||
if args and hasattr(args[0], "__class__"):
|
||||
class_name = f"{args[0].__class__.__name__}."
|
||||
|
||||
return (
|
||||
f"Synchronous method '{class_name}{method_name}()' called from async context. "
|
||||
f"Use '{class_name}{async_method_name}()' instead.\n"
|
||||
f"Calling sync methods from async code can cause deadlocks and performance issues."
|
||||
)
|
||||
|
||||
|
||||
def sync_only(func: Callable[..., T]) -> Callable[..., T]: # noqa: UP047
|
||||
"""Decorator to ensure a method is only called from sync context.
|
||||
|
||||
This helps prevent a common footgun where users accidentally call
|
||||
sync methods from async code, which can cause deadlocks and performance issues.
|
||||
|
||||
Assumes (in error message) that the wrapped method has a corresponding method name
|
||||
{method_name}_async.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> T:
|
||||
if is_in_async_context() and not is_jupyter():
|
||||
error_message = make_error_message(func, args, kwargs)
|
||||
logger.warning(error_message)
|
||||
logger.warning(
|
||||
f"===== Stack for calling sync from async ===== \n{traceback.format_stack()}\n ==========="
|
||||
)
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
383
src/tinker/lib/telemetry.py
Normal file
383
src/tinker/lib/telemetry.py
Normal file
|
|
@ -0,0 +1,383 @@
|
|||
import asyncio
|
||||
import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import threading
|
||||
import traceback
|
||||
from collections import deque
|
||||
from collections.abc import Awaitable
|
||||
from datetime import datetime, timezone
|
||||
from typing import (
|
||||
Callable,
|
||||
ParamSpec,
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
from uuid import uuid4
|
||||
|
||||
from tinker._exceptions import APIError
|
||||
from tinker._version import __version__
|
||||
from tinker.types.generic_event import GenericEvent
|
||||
from tinker.types.session_end_event import SessionEndEvent
|
||||
from tinker.types.session_start_event import SessionStartEvent
|
||||
from tinker.types.severity import Severity
|
||||
from tinker.types.telemetry_batch import TelemetryBatch
|
||||
from tinker.types.telemetry_event import TelemetryEvent
|
||||
from tinker.types.telemetry_response import TelemetryResponse
|
||||
from tinker.types.telemetry_send_params import TelemetrySendParams
|
||||
from tinker.types.unhandled_exception_event import UnhandledExceptionEvent
|
||||
|
||||
from .async_tinker_provider import AsyncTinkerProvider
|
||||
from .client_connection_pool_type import ClientConnectionPoolType
|
||||
from .sync_only import sync_only
|
||||
from .telemetry_provider import TelemetryProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_BATCH_SIZE: int = 100
|
||||
FLUSH_INTERVAL: float = 10.0
|
||||
FLUSH_TIMEOUT: float = 30.0
|
||||
MAX_QUEUE_SIZE: int = 10000
|
||||
HTTP_TIMEOUT_SECONDS: float = 5.0
|
||||
|
||||
|
||||
class Telemetry:
|
||||
def __init__(self, tinker_provider: AsyncTinkerProvider, session_id: str):
|
||||
self._tinker_provider: AsyncTinkerProvider = tinker_provider
|
||||
self._session_id: str = session_id
|
||||
self._session_start: datetime = datetime.now(timezone.utc)
|
||||
self._session_index: int = 0
|
||||
self._session_index_lock: threading.Lock = threading.Lock()
|
||||
self._queue: deque[TelemetryEvent] = deque()
|
||||
self._queue_lock: threading.Lock = threading.Lock()
|
||||
self._task: asyncio.Task[None] | None = None
|
||||
self._flush_event: asyncio.Event | None = None
|
||||
self._push_counter: int = 0
|
||||
self._flush_counter: int = 0
|
||||
self._counter_lock: threading.Lock = threading.Lock()
|
||||
_ = self._log(self._session_start_event())
|
||||
self._start()
|
||||
|
||||
def _start(self):
|
||||
def cb():
|
||||
self._flush_event = asyncio.Event()
|
||||
self._task = asyncio.create_task(self._periodic_flush(), name="tinker-telemetry")
|
||||
|
||||
_ = self._tinker_provider.get_loop().call_soon_threadsafe(cb)
|
||||
|
||||
def stop(self):
|
||||
def cb():
|
||||
if task := self._task:
|
||||
_ = task.cancel()
|
||||
|
||||
_ = self._tinker_provider.get_loop().call_soon_threadsafe(cb)
|
||||
|
||||
async def _periodic_flush(self):
|
||||
while True:
|
||||
if self._flush_event:
|
||||
try:
|
||||
_ = await asyncio.wait_for(self._flush_event.wait(), timeout=FLUSH_INTERVAL)
|
||||
except TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
self._flush_event.clear()
|
||||
await self._flush()
|
||||
|
||||
async def _flush(self):
|
||||
while True:
|
||||
with self._queue_lock:
|
||||
if not self._queue:
|
||||
break
|
||||
batch_size = min(MAX_BATCH_SIZE, len(self._queue))
|
||||
events = [self._queue.popleft() for _ in range(batch_size)]
|
||||
batch = self._batch(events)
|
||||
try:
|
||||
_ = await self._send_batch_with_retry(batch)
|
||||
finally:
|
||||
# increment counter even if we fail to send the batch so we're not blocking
|
||||
# on a flush for non-APIErrors (e.g. missing API key)
|
||||
with self._counter_lock:
|
||||
self._flush_counter += len(events)
|
||||
|
||||
def _trigger_flush(self):
|
||||
if self._flush_event:
|
||||
_ = self._tinker_provider.get_loop().call_soon_threadsafe(self._flush_event.set)
|
||||
|
||||
async def _wait_until_drained(self) -> bool:
|
||||
with self._counter_lock:
|
||||
target_count = self._push_counter
|
||||
start = asyncio.get_event_loop().time()
|
||||
while asyncio.get_event_loop().time() - start < FLUSH_TIMEOUT:
|
||||
with self._counter_lock:
|
||||
if self._flush_counter >= target_count:
|
||||
return True
|
||||
await asyncio.sleep(0.1)
|
||||
return False
|
||||
|
||||
def _wait_until_drained_sync(self) -> bool:
|
||||
try:
|
||||
return asyncio.run_coroutine_threadsafe(
|
||||
self._wait_until_drained(), self._tinker_provider.get_loop()
|
||||
).result(timeout=FLUSH_TIMEOUT)
|
||||
except (TimeoutError, asyncio.CancelledError):
|
||||
return False
|
||||
|
||||
async def _send_batch_with_retry(self, batch: TelemetryBatch) -> TelemetryResponse:
|
||||
while True:
|
||||
try:
|
||||
return await self._send_batch(batch)
|
||||
except APIError as e:
|
||||
logger.warning("Failed to send telemetry batch", exc_info=e)
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
async def _send_batch(self, batch: TelemetryBatch) -> TelemetryResponse:
|
||||
with self._tinker_provider.aclient(ClientConnectionPoolType.TELEMETRY) as client:
|
||||
params = _to_send_params(batch)
|
||||
return await client.telemetry.send(**params, timeout=HTTP_TIMEOUT_SECONDS)
|
||||
|
||||
def _log(self, *events: TelemetryEvent) -> bool:
|
||||
with self._queue_lock:
|
||||
if len(self._queue) + len(events) > MAX_QUEUE_SIZE:
|
||||
logger.warning("Telemetry queue full, dropping events")
|
||||
return False
|
||||
self._queue.extend(events)
|
||||
with self._counter_lock:
|
||||
self._push_counter += len(events)
|
||||
return True
|
||||
|
||||
def log(
|
||||
self,
|
||||
event_name: str, # should be low cardinality
|
||||
event_data: dict[str, object] | None = None,
|
||||
severity: Severity = "INFO",
|
||||
) -> bool:
|
||||
return self._log(self._generic_event(event_name, event_data, severity))
|
||||
|
||||
async def log_exception(self, exception: BaseException, severity: Severity = "ERROR") -> bool:
|
||||
logged = self._log(self._exception_event(exception, severity))
|
||||
# trigger flush but don't block on it
|
||||
self._trigger_flush()
|
||||
return logged
|
||||
|
||||
async def log_fatal_exception(
|
||||
self, exception: BaseException, severity: Severity = "ERROR"
|
||||
) -> bool:
|
||||
logged = self._log(self._exception_event(exception, severity), self._session_end_event())
|
||||
self._trigger_flush()
|
||||
# wait for the flush to complete
|
||||
_ = await self._wait_until_drained()
|
||||
if logged:
|
||||
self._notify_exception_logged()
|
||||
return logged
|
||||
|
||||
@sync_only
|
||||
def log_exception_sync(self, exception: BaseException, severity: Severity = "ERROR") -> bool:
|
||||
logged = self._log(self._exception_event(exception, severity))
|
||||
# trigger flush but don't block on it
|
||||
self._trigger_flush()
|
||||
return logged
|
||||
|
||||
@sync_only
|
||||
def log_fatal_exception_sync(
|
||||
self, exception: BaseException, severity: Severity = "ERROR"
|
||||
) -> bool:
|
||||
logged = self._log(self._exception_event(exception, severity), self._session_end_event())
|
||||
self._trigger_flush()
|
||||
# wait for the flush to complete
|
||||
if _current_loop() is None:
|
||||
_ = self._wait_until_drained_sync()
|
||||
if logged:
|
||||
self._notify_exception_logged()
|
||||
return logged
|
||||
|
||||
def _notify_exception_logged(self):
|
||||
logger.info(f"Exception logged for session ID: {self._session_id}")
|
||||
|
||||
def _batch(self, events: list[TelemetryEvent]) -> TelemetryBatch:
|
||||
return TelemetryBatch(
|
||||
platform=platform.system(),
|
||||
sdk_version=__version__,
|
||||
session_id=self._session_id,
|
||||
events=events,
|
||||
)
|
||||
|
||||
def _generic_event(
|
||||
self,
|
||||
event_name: str,
|
||||
event_data: dict[str, object] | None = None,
|
||||
severity: Severity = "INFO",
|
||||
) -> GenericEvent:
|
||||
return GenericEvent(
|
||||
event="GENERIC_EVENT",
|
||||
event_id=str(uuid4()),
|
||||
event_session_index=self._next_session_index(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
severity=severity,
|
||||
event_name=event_name,
|
||||
event_data=event_data or {},
|
||||
)
|
||||
|
||||
def _session_start_event(self) -> SessionStartEvent:
|
||||
return SessionStartEvent(
|
||||
event="SESSION_START",
|
||||
event_id=str(uuid4()),
|
||||
event_session_index=self._next_session_index(),
|
||||
timestamp=self._session_start,
|
||||
severity="INFO",
|
||||
)
|
||||
|
||||
def _session_end_event(self) -> SessionEndEvent:
|
||||
end_time = datetime.now(timezone.utc)
|
||||
return SessionEndEvent(
|
||||
event="SESSION_END",
|
||||
event_id=str(uuid4()),
|
||||
event_session_index=self._next_session_index(),
|
||||
timestamp=end_time,
|
||||
severity="INFO",
|
||||
duration=str(end_time - self._session_start),
|
||||
)
|
||||
|
||||
def _exception_event(
|
||||
self, exception: BaseException, severity: Severity
|
||||
) -> UnhandledExceptionEvent:
|
||||
return UnhandledExceptionEvent(
|
||||
event="UNHANDLED_EXCEPTION",
|
||||
event_id=str(uuid4()),
|
||||
event_session_index=self._next_session_index(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
severity=severity,
|
||||
error_type=exception.__class__.__name__,
|
||||
error_message=str(exception),
|
||||
traceback="".join(
|
||||
traceback.format_exception(type(exception), exception, exception.__traceback__)
|
||||
)
|
||||
if exception.__traceback__
|
||||
else None,
|
||||
)
|
||||
|
||||
def _next_session_index(self) -> int:
|
||||
with self._session_index_lock:
|
||||
idx = self._session_index
|
||||
self._session_index += 1
|
||||
return idx
|
||||
|
||||
@contextlib.contextmanager
|
||||
def capture_exceptions(self, fatal: bool = False, severity: Severity = "ERROR"):
|
||||
try:
|
||||
yield
|
||||
except Exception as e:
|
||||
if fatal:
|
||||
_ = self.log_fatal_exception_sync(e, severity)
|
||||
else:
|
||||
_ = self.log_exception_sync(e, severity)
|
||||
raise
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def acapture_exceptions(self, fatal: bool = False, severity: Severity = "ERROR"):
|
||||
try:
|
||||
yield
|
||||
except Exception as e:
|
||||
if fatal:
|
||||
_ = await self.log_fatal_exception(e, severity)
|
||||
else:
|
||||
_ = await self.log_exception(e, severity)
|
||||
raise
|
||||
|
||||
|
||||
def _is_telemetry_enabled() -> bool:
|
||||
return os.environ.get("TINKER_TELEMETRY", "1").lower() in {
|
||||
"1",
|
||||
"true",
|
||||
"yes",
|
||||
"on",
|
||||
}
|
||||
|
||||
|
||||
def init_telemetry(tinker_provider: AsyncTinkerProvider, session_id: str) -> Telemetry | None:
|
||||
try:
|
||||
return Telemetry(tinker_provider, session_id) if _is_telemetry_enabled() else None
|
||||
except Exception as e:
|
||||
logger.warning(f"Error initializing telemetry: {e}")
|
||||
return None
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
# Decorator to capture exceptions. Class must implement TelemetryProvider.
|
||||
# Pass fatal=True to log a session end event in addition to the exception.
|
||||
#
|
||||
# Example:
|
||||
# @capture_exceptions
|
||||
# def my_method(self):
|
||||
# pass
|
||||
#
|
||||
# @capture_exceptions(fatal=True, severity="CRITICAL")
|
||||
# def my_method(self):
|
||||
# pass
|
||||
@overload
|
||||
def capture_exceptions(
|
||||
func: Callable[P, R], *, fatal: bool = False, severity: Severity = "ERROR"
|
||||
) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def capture_exceptions(
|
||||
*, fatal: bool = False, severity: Severity = "ERROR"
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
|
||||
|
||||
|
||||
def capture_exceptions(
|
||||
func: Callable[P, R] | None = None, *, fatal: bool = False, severity: Severity = "ERROR"
|
||||
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def _get_telemetry(func: Callable[..., object], args: tuple[object, ...]) -> Telemetry | None:
|
||||
if args and isinstance(args[0], TelemetryProvider):
|
||||
return args[0].get_telemetry()
|
||||
with contextlib.suppress(TypeError, AttributeError):
|
||||
self = inspect.getclosurevars(func).nonlocals.get("self")
|
||||
if isinstance(self, TelemetryProvider):
|
||||
return self.get_telemetry()
|
||||
logger.warning("@capture_exceptions used without TelemetryProvider: %s", func.__name__)
|
||||
return None
|
||||
|
||||
def _decorate(func: Callable[P, R]) -> Callable[P, R]:
|
||||
if inspect.iscoroutinefunction(func):
|
||||
|
||||
@functools.wraps(func)
|
||||
async def _awrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
telemetry = _get_telemetry(func, args)
|
||||
if telemetry is None:
|
||||
return await cast(Callable[..., Awaitable[R]], func)(*args, **kwargs)
|
||||
async with telemetry.acapture_exceptions(fatal=fatal, severity=severity):
|
||||
return await cast(Callable[..., Awaitable[R]], func)(*args, **kwargs)
|
||||
|
||||
return cast(Callable[P, R], _awrapper)
|
||||
|
||||
@functools.wraps(func)
|
||||
def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
telemetry = _get_telemetry(func, args)
|
||||
if telemetry is None:
|
||||
return func(*args, **kwargs)
|
||||
with telemetry.capture_exceptions(fatal=fatal, severity=severity):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return cast(Callable[P, R], _wrapper)
|
||||
|
||||
return _decorate if func is None else _decorate(func)
|
||||
|
||||
|
||||
def _to_send_params(batch: TelemetryBatch) -> TelemetrySendParams:
|
||||
return cast(TelemetrySendParams, cast(object, batch.model_dump()))
|
||||
|
||||
|
||||
def _current_loop() -> asyncio.AbstractEventLoop | None:
|
||||
try:
|
||||
return asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return None
|
||||
11
src/tinker/lib/telemetry_provider.py
Normal file
11
src/tinker/lib/telemetry_provider.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .telemetry import Telemetry
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class TelemetryProvider(Protocol):
|
||||
def get_telemetry(self) -> Telemetry | None: ...
|
||||
642
src/tinker/lib/telemetry_test.py
Normal file
642
src/tinker/lib/telemetry_test.py
Normal file
|
|
@ -0,0 +1,642 @@
|
|||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import platform
|
||||
import threading
|
||||
from collections.abc import Coroutine
|
||||
from concurrent.futures import Future as ConcurrentFuture
|
||||
from typing import Any, TypeVar, cast
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
|
||||
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
|
||||
from tinker.lib.telemetry import (
|
||||
MAX_BATCH_SIZE,
|
||||
MAX_QUEUE_SIZE,
|
||||
Telemetry,
|
||||
_is_telemetry_enabled,
|
||||
capture_exceptions,
|
||||
init_telemetry,
|
||||
)
|
||||
from tinker.lib.telemetry_provider import TelemetryProvider
|
||||
from tinker.types.generic_event import GenericEvent
|
||||
from tinker.types.session_end_event import SessionEndEvent
|
||||
from tinker.types.session_start_event import SessionStartEvent
|
||||
from tinker.types.telemetry_batch import TelemetryBatch
|
||||
from tinker.types.telemetry_event import TelemetryEvent
|
||||
from tinker.types.telemetry_response import TelemetryResponse
|
||||
from tinker.types.unhandled_exception_event import UnhandledExceptionEvent
|
||||
|
||||
# pyright: reportMissingParameterType=false
|
||||
# pyright: reportOptionalMemberAccess=false
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class MockEventLoopProvider:
|
||||
def __init__(self):
|
||||
self.loop = Mock()
|
||||
self.loop.call_soon_threadsafe = Mock()
|
||||
|
||||
def get_loop(self):
|
||||
return self.loop
|
||||
|
||||
def run_coroutine_threadsafe(self, coro: Coroutine[Any, Any, T]):
|
||||
future = Mock()
|
||||
future.result = Mock(return_value=asyncio.run(coro))
|
||||
return future
|
||||
|
||||
|
||||
class MockAsyncTinkerProvider:
|
||||
def __init__(self):
|
||||
self.loop = Mock()
|
||||
self.callbacks = []
|
||||
self.loop.call_soon_threadsafe = Mock(side_effect=lambda cb: self.callbacks.append(cb))
|
||||
self._client = MagicMock()
|
||||
self.telemetry_send_mock = AsyncMock(return_value=TelemetryResponse(status="accepted"))
|
||||
self._client.telemetry.send = self.telemetry_send_mock
|
||||
self._client_cm = MagicMock()
|
||||
self._client_cm.__enter__ = Mock(return_value=self._client)
|
||||
self._client_cm.__exit__ = Mock(return_value=None)
|
||||
|
||||
def execute_callbacks(self):
|
||||
for cb in self.callbacks:
|
||||
cb()
|
||||
self.callbacks.clear()
|
||||
|
||||
def get_loop(self):
|
||||
return self.loop
|
||||
|
||||
def run_coroutine_threadsafe(self, coro: Coroutine[Any, Any, T]):
|
||||
fut: ConcurrentFuture[Any] = ConcurrentFuture()
|
||||
|
||||
async def _runner():
|
||||
try:
|
||||
result = await coro
|
||||
fut.set_result(result)
|
||||
except Exception as e:
|
||||
fut.set_exception(e)
|
||||
|
||||
_ = asyncio.get_event_loop_policy().get_event_loop().create_task(_runner())
|
||||
return AwaitableConcurrentFuture(fut)
|
||||
|
||||
def aclient(self, client_pool_type: ClientConnectionPoolType):
|
||||
_ = client_pool_type
|
||||
return self._client_cm
|
||||
|
||||
|
||||
class TestTelemetryClass:
|
||||
def setup_method(self):
|
||||
self.tinker_provider = MockAsyncTinkerProvider()
|
||||
self.telemetry = Telemetry(self.tinker_provider, session_id="test-session-id")
|
||||
|
||||
def teardown_method(self):
|
||||
if hasattr(self, "telemetry"):
|
||||
self.telemetry.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialization(self):
|
||||
self.tinker_provider.execute_callbacks()
|
||||
assert self.telemetry._session_index == 1
|
||||
assert len(self.telemetry._queue) == 1
|
||||
assert isinstance(self.telemetry._queue[0], SessionStartEvent)
|
||||
assert isinstance(self.telemetry._flush_event, asyncio.Event)
|
||||
|
||||
def test_log_single_event(self):
|
||||
event = self.telemetry._session_end_event()
|
||||
result = self.telemetry._log(event)
|
||||
assert result is True
|
||||
assert len(self.telemetry._queue) == 2
|
||||
assert self.telemetry._queue[-1] == event
|
||||
|
||||
def test_log_multiple_events(self):
|
||||
event1 = self.telemetry._session_end_event()
|
||||
event2 = self.telemetry._exception_event(ValueError("test"), "ERROR")
|
||||
result = self.telemetry._log(event1, event2)
|
||||
assert result is True
|
||||
assert len(self.telemetry._queue) == 3
|
||||
assert self.telemetry._queue[-2] == event1
|
||||
assert self.telemetry._queue[-1] == event2
|
||||
|
||||
def test_log_generic_event_default(self):
|
||||
idx_before = self.telemetry._session_index
|
||||
result = self.telemetry.log("test-event")
|
||||
assert result is True
|
||||
assert len(self.telemetry._queue) == 2
|
||||
event = self.telemetry._queue[-1]
|
||||
assert isinstance(event, GenericEvent)
|
||||
assert event.event == "GENERIC_EVENT"
|
||||
assert event.event_name == "test-event"
|
||||
assert event.severity == "INFO"
|
||||
assert event.event_data == {}
|
||||
assert event.event_session_index == idx_before
|
||||
assert isinstance(event.event_id, str) and event.event_id
|
||||
|
||||
def test_log_generic_event_custom(self):
|
||||
idx_before = self.telemetry._session_index
|
||||
payload: dict[str, object] = {"a": 1, "b": "x"}
|
||||
result = self.telemetry.log("custom-event", event_data=payload, severity="WARNING")
|
||||
assert result is True
|
||||
assert len(self.telemetry._queue) == 2
|
||||
event = self.telemetry._queue[-1]
|
||||
assert isinstance(event, GenericEvent)
|
||||
assert event.event == "GENERIC_EVENT"
|
||||
assert event.event_name == "custom-event"
|
||||
assert event.severity == "WARNING"
|
||||
assert event.event_data == payload
|
||||
assert event.event_session_index == idx_before
|
||||
|
||||
def test_log_queue_full(self):
|
||||
initial_size = len(self.telemetry._queue)
|
||||
events_to_add = MAX_QUEUE_SIZE - initial_size - 1
|
||||
for _ in range(events_to_add):
|
||||
self.telemetry._queue.append(self.telemetry._session_end_event())
|
||||
assert len(self.telemetry._queue) == MAX_QUEUE_SIZE - 1
|
||||
event1 = self.telemetry._session_end_event()
|
||||
event2 = self.telemetry._session_end_event()
|
||||
with patch("tinker.lib.telemetry.logger") as mock_logger:
|
||||
result = self.telemetry._log(event1, event2)
|
||||
assert result is False
|
||||
assert len(self.telemetry._queue) == MAX_QUEUE_SIZE - 1
|
||||
mock_logger.warning.assert_called_once_with("Telemetry queue full, dropping events")
|
||||
|
||||
def test_batch_creation(self):
|
||||
events = [
|
||||
self.telemetry._session_start_event(),
|
||||
self.telemetry._session_end_event(),
|
||||
]
|
||||
batch = self.telemetry._batch(cast(list[TelemetryEvent], events))
|
||||
assert isinstance(batch, TelemetryBatch)
|
||||
assert batch.platform == platform.system()
|
||||
assert batch.session_id == str(self.telemetry._session_id)
|
||||
assert batch.events == events
|
||||
assert batch.sdk_version is not None
|
||||
|
||||
def test_log_exception_sync(self):
|
||||
try:
|
||||
raise RuntimeError("Test exception")
|
||||
except RuntimeError as e:
|
||||
with patch.object(self.telemetry, "_trigger_flush") as mock_trigger:
|
||||
result = self.telemetry.log_exception_sync(e, "ERROR")
|
||||
|
||||
assert result is True
|
||||
assert len(self.telemetry._queue) == 2
|
||||
mock_trigger.assert_called_once()
|
||||
exception_event = self.telemetry._queue[-1]
|
||||
assert isinstance(exception_event, UnhandledExceptionEvent)
|
||||
assert exception_event.error_type == "RuntimeError"
|
||||
assert exception_event.error_message == "Test exception"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_exception_async(self):
|
||||
try:
|
||||
raise RuntimeError("Test exception")
|
||||
except RuntimeError as e:
|
||||
with patch.object(self.telemetry, "_trigger_flush") as mock_trigger:
|
||||
result = await self.telemetry.log_exception(e, "ERROR")
|
||||
|
||||
assert result is True
|
||||
assert len(self.telemetry._queue) == 2
|
||||
mock_trigger.assert_called_once()
|
||||
exception_event = self.telemetry._queue[-1]
|
||||
assert isinstance(exception_event, UnhandledExceptionEvent)
|
||||
assert exception_event.error_type == "RuntimeError"
|
||||
assert exception_event.error_message == "Test exception"
|
||||
|
||||
def test_log_fatal_exception_sync(self):
|
||||
try:
|
||||
raise RuntimeError("Fatal error")
|
||||
except RuntimeError as e:
|
||||
with patch.object(self.telemetry, "_trigger_flush") as mock_trigger:
|
||||
with patch.object(
|
||||
self.telemetry, "_wait_until_drained_sync", return_value=True
|
||||
) as mock_wait:
|
||||
result = self.telemetry.log_fatal_exception_sync(e, "CRITICAL")
|
||||
|
||||
assert result is True
|
||||
assert len(self.telemetry._queue) == 3
|
||||
mock_trigger.assert_called_once()
|
||||
mock_wait.assert_called_once()
|
||||
exception_event = self.telemetry._queue[-2]
|
||||
assert isinstance(exception_event, UnhandledExceptionEvent)
|
||||
assert exception_event.severity == "CRITICAL"
|
||||
end_event = self.telemetry._queue[-1]
|
||||
assert isinstance(end_event, SessionEndEvent)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_fatal_exception_async(self):
|
||||
try:
|
||||
raise RuntimeError("Fatal error")
|
||||
except RuntimeError as e:
|
||||
with patch.object(self.telemetry, "_trigger_flush") as mock_trigger:
|
||||
with patch.object(
|
||||
self.telemetry, "_wait_until_drained", new_callable=AsyncMock, return_value=True
|
||||
) as mock_wait:
|
||||
result = await self.telemetry.log_fatal_exception(e, "CRITICAL")
|
||||
|
||||
assert result is True
|
||||
assert len(self.telemetry._queue) == 3
|
||||
mock_trigger.assert_called_once()
|
||||
mock_wait.assert_called_once()
|
||||
exception_event = self.telemetry._queue[-2]
|
||||
assert isinstance(exception_event, UnhandledExceptionEvent)
|
||||
assert exception_event.severity == "CRITICAL"
|
||||
end_event = self.telemetry._queue[-1]
|
||||
assert isinstance(end_event, SessionEndEvent)
|
||||
|
||||
def test_capture_exceptions_context_manager(self):
|
||||
with patch.object(self.telemetry, "_trigger_flush"):
|
||||
with pytest.raises(ValueError):
|
||||
with self.telemetry.capture_exceptions():
|
||||
raise ValueError("Test error")
|
||||
|
||||
assert len(self.telemetry._queue) == 2
|
||||
exception_event = self.telemetry._queue[-1]
|
||||
assert isinstance(exception_event, UnhandledExceptionEvent)
|
||||
assert exception_event.error_type == "ValueError"
|
||||
|
||||
def test_capture_exceptions_context_manager_fatal(self):
|
||||
with patch.object(self.telemetry, "_trigger_flush"):
|
||||
with patch.object(self.telemetry, "_wait_until_drained_sync", return_value=True):
|
||||
with pytest.raises(ValueError):
|
||||
with self.telemetry.capture_exceptions(fatal=True, severity="CRITICAL"):
|
||||
raise ValueError("Fatal error")
|
||||
|
||||
assert len(self.telemetry._queue) == 3
|
||||
exception_event = self.telemetry._queue[-2]
|
||||
assert isinstance(exception_event, UnhandledExceptionEvent)
|
||||
assert exception_event.severity == "CRITICAL"
|
||||
end_event = self.telemetry._queue[-1]
|
||||
assert isinstance(end_event, SessionEndEvent)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acapture_exceptions_context_manager(self):
|
||||
with patch.object(self.telemetry, "_trigger_flush"):
|
||||
with pytest.raises(ValueError):
|
||||
async with self.telemetry.acapture_exceptions():
|
||||
raise ValueError("Async test error")
|
||||
|
||||
assert len(self.telemetry._queue) == 2
|
||||
exception_event = self.telemetry._queue[-1]
|
||||
assert isinstance(exception_event, UnhandledExceptionEvent)
|
||||
assert exception_event.error_type == "ValueError"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acapture_exceptions_context_manager_fatal(self):
|
||||
with patch.object(self.telemetry, "_trigger_flush"):
|
||||
with patch.object(
|
||||
self.telemetry, "_wait_until_drained", new_callable=AsyncMock, return_value=True
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
async with self.telemetry.acapture_exceptions(fatal=True):
|
||||
raise ValueError("Async fatal error")
|
||||
assert len(self.telemetry._queue) == 3
|
||||
exception_event = self.telemetry._queue[-2]
|
||||
assert isinstance(exception_event, UnhandledExceptionEvent)
|
||||
end_event = self.telemetry._queue[-1]
|
||||
assert isinstance(end_event, SessionEndEvent)
|
||||
|
||||
|
||||
class TestTelemetryEnvironment:
|
||||
@pytest.mark.parametrize(
|
||||
"env_value,expected",
|
||||
[
|
||||
("1", True),
|
||||
("true", True),
|
||||
("True", True),
|
||||
("TRUE", True),
|
||||
("yes", True),
|
||||
("Yes", True),
|
||||
("on", True),
|
||||
("ON", True),
|
||||
("0", False),
|
||||
("false", False),
|
||||
("no", False),
|
||||
("off", False),
|
||||
("", False),
|
||||
("random", False),
|
||||
],
|
||||
)
|
||||
def test_is_telemetry_enabled(self, env_value, expected):
|
||||
with patch.dict(os.environ, {"TINKER_TELEMETRY": env_value}):
|
||||
assert _is_telemetry_enabled() == expected
|
||||
|
||||
def test_is_telemetry_enabled_not_set(self):
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
assert _is_telemetry_enabled() is True
|
||||
|
||||
def test_init_telemetry_enabled(self):
|
||||
with patch.dict(os.environ, {"TINKER_TELEMETRY": "1"}):
|
||||
tinker_provider = MockAsyncTinkerProvider()
|
||||
telemetry = init_telemetry(tinker_provider, session_id="test-session-id")
|
||||
assert telemetry is not None
|
||||
assert isinstance(telemetry, Telemetry)
|
||||
telemetry.stop()
|
||||
|
||||
def test_init_telemetry_disabled(self):
|
||||
with patch.dict(os.environ, {"TINKER_TELEMETRY": "0"}):
|
||||
tinker_provider = MockAsyncTinkerProvider()
|
||||
telemetry = init_telemetry(tinker_provider, session_id="test-session-id")
|
||||
assert telemetry is None
|
||||
|
||||
def test_init_telemetry_with_exception(self):
|
||||
with patch.dict(os.environ, {"TINKER_TELEMETRY": "1"}):
|
||||
tinker_provider = Mock()
|
||||
tinker_provider.get_loop.side_effect = Exception("Init error")
|
||||
with patch("tinker.lib.telemetry.logger") as mock_logger:
|
||||
telemetry = init_telemetry(tinker_provider, session_id="test-session-id")
|
||||
assert telemetry is None
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert "Error initializing telemetry" in str(mock_logger.warning.call_args)
|
||||
|
||||
|
||||
class TestCaptureExceptionsDecorator:
|
||||
class MockTelemetryProvider:
|
||||
def __init__(self):
|
||||
self.telemetry = Mock()
|
||||
self.telemetry.capture_exceptions = Mock()
|
||||
self.telemetry.acapture_exceptions = Mock()
|
||||
|
||||
def get_telemetry(self):
|
||||
return self.telemetry
|
||||
|
||||
def test_decorator_on_sync_function(self):
|
||||
provider = self.MockTelemetryProvider()
|
||||
|
||||
@capture_exceptions
|
||||
def test_func(self):
|
||||
return "success"
|
||||
|
||||
provider.telemetry.capture_exceptions.return_value.__enter__ = Mock()
|
||||
provider.telemetry.capture_exceptions.return_value.__exit__ = Mock(return_value=False)
|
||||
result = test_func(provider)
|
||||
assert result == "success"
|
||||
provider.telemetry.capture_exceptions.assert_called_once_with(fatal=False, severity="ERROR")
|
||||
|
||||
def test_decorator_on_sync_function_with_exception(self):
|
||||
provider = self.MockTelemetryProvider()
|
||||
|
||||
@capture_exceptions(fatal=True, severity="CRITICAL")
|
||||
def test_func(self):
|
||||
raise ValueError("Test error")
|
||||
|
||||
provider.telemetry.capture_exceptions.return_value.__enter__ = Mock()
|
||||
provider.telemetry.capture_exceptions.return_value.__exit__ = Mock(return_value=False)
|
||||
with pytest.raises(ValueError, match="Test error"):
|
||||
test_func(provider)
|
||||
provider.telemetry.capture_exceptions.assert_called_once_with(
|
||||
fatal=True, severity="CRITICAL"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_on_async_function(self):
|
||||
provider = self.MockTelemetryProvider()
|
||||
|
||||
@capture_exceptions
|
||||
async def test_func(self):
|
||||
return "async success"
|
||||
|
||||
async_cm = AsyncMock()
|
||||
async_cm.__aenter__ = AsyncMock(return_value=None)
|
||||
async_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
provider.telemetry.acapture_exceptions.return_value = async_cm
|
||||
result = await test_func(provider)
|
||||
assert result == "async success"
|
||||
provider.telemetry.acapture_exceptions.assert_called_once_with(
|
||||
fatal=False, severity="ERROR"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_on_async_function_with_exception(self):
|
||||
provider = self.MockTelemetryProvider()
|
||||
|
||||
@capture_exceptions(severity="WARNING")
|
||||
async def test_func(self):
|
||||
raise RuntimeError("Async error")
|
||||
|
||||
async_cm = AsyncMock()
|
||||
async_cm.__aenter__ = AsyncMock(return_value=None)
|
||||
async_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
provider.telemetry.acapture_exceptions.return_value = async_cm
|
||||
with pytest.raises(RuntimeError, match="Async error"):
|
||||
await test_func(provider)
|
||||
provider.telemetry.acapture_exceptions.assert_called_once_with(
|
||||
fatal=False, severity="WARNING"
|
||||
)
|
||||
|
||||
def test_decorator_without_telemetry_provider(self):
|
||||
@capture_exceptions
|
||||
def test_func():
|
||||
return "no provider"
|
||||
|
||||
result = test_func()
|
||||
assert result == "no provider"
|
||||
|
||||
def test_decorator_with_non_provider_self(self):
|
||||
class NonProvider:
|
||||
@capture_exceptions
|
||||
def test_method(self):
|
||||
return "not a provider"
|
||||
|
||||
obj = NonProvider()
|
||||
result = obj.test_method()
|
||||
assert result == "not a provider"
|
||||
|
||||
def test_decorator_as_plain_decorator(self):
|
||||
@capture_exceptions
|
||||
def test_func():
|
||||
return "plain decorator"
|
||||
|
||||
result = test_func()
|
||||
assert result == "plain decorator"
|
||||
|
||||
def test_decorator_with_parentheses(self):
|
||||
@capture_exceptions()
|
||||
def test_func():
|
||||
return "decorator with parens"
|
||||
|
||||
result = test_func()
|
||||
assert result == "decorator with parens"
|
||||
|
||||
def test_decorator_on_inner_sync_function_closing_over_self(self):
|
||||
provider = self.MockTelemetryProvider()
|
||||
|
||||
class Wrapper:
|
||||
def __init__(self, p: Any):
|
||||
self.p: Any = p
|
||||
|
||||
@capture_exceptions
|
||||
def outer(self) -> str:
|
||||
@capture_exceptions
|
||||
def inner() -> str:
|
||||
# reference `self` so it is captured in the closure
|
||||
return "ok" if self else "bad"
|
||||
|
||||
return inner()
|
||||
|
||||
def get_telemetry(self) -> Telemetry | None:
|
||||
return self.p.get_telemetry()
|
||||
|
||||
wrapper = Wrapper(provider)
|
||||
provider.telemetry.capture_exceptions.return_value.__enter__ = Mock()
|
||||
provider.telemetry.capture_exceptions.return_value.__exit__ = Mock(return_value=False)
|
||||
result = wrapper.outer()
|
||||
assert result == "ok"
|
||||
# Called twice: once for outer, once for inner via closure lookup
|
||||
assert provider.telemetry.capture_exceptions.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_on_inner_async_function_closing_over_self(self):
|
||||
provider = self.MockTelemetryProvider()
|
||||
|
||||
class Wrapper:
|
||||
def __init__(self, p: Any):
|
||||
self.p: Any = p
|
||||
|
||||
@capture_exceptions
|
||||
async def outer(self) -> str:
|
||||
@capture_exceptions
|
||||
async def inner() -> str:
|
||||
# reference `self` so it is captured in the closure
|
||||
return "ok-async" if self else "bad"
|
||||
|
||||
return await inner()
|
||||
|
||||
def get_telemetry(self) -> Telemetry | None:
|
||||
return self.p.get_telemetry()
|
||||
|
||||
wrapper = Wrapper(provider)
|
||||
async_cm = AsyncMock()
|
||||
async_cm.__aenter__ = AsyncMock(return_value=None)
|
||||
async_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
provider.telemetry.acapture_exceptions.return_value = async_cm
|
||||
result = await wrapper.outer()
|
||||
assert result == "ok-async"
|
||||
# Called twice: once for outer, once for inner via closure lookup
|
||||
assert provider.telemetry.acapture_exceptions.call_count == 2
|
||||
|
||||
|
||||
class TestTelemetryFlush:
|
||||
def setup_method(self):
|
||||
self.tinker_provider = MockAsyncTinkerProvider()
|
||||
self.telemetry = Telemetry(self.tinker_provider, session_id="test-session-id")
|
||||
|
||||
def teardown_method(self):
|
||||
if hasattr(self, "telemetry"):
|
||||
self.telemetry.stop()
|
||||
|
||||
def test_flush_empty_queue(self):
|
||||
self.telemetry._queue.clear()
|
||||
with patch.object(
|
||||
self.telemetry, "_send_batch_with_retry", new_callable=AsyncMock
|
||||
) as mock_send:
|
||||
asyncio.run(self.telemetry._flush())
|
||||
mock_send.assert_not_called()
|
||||
|
||||
def test_flush_small_batch(self):
|
||||
for _ in range(5):
|
||||
_ = self.telemetry._log(self.telemetry._session_end_event())
|
||||
with patch.object(
|
||||
self.telemetry, "_send_batch_with_retry", new_callable=AsyncMock
|
||||
) as mock_send:
|
||||
asyncio.run(self.telemetry._flush())
|
||||
assert len(self.telemetry._queue) == 0
|
||||
assert mock_send.call_count == 1
|
||||
|
||||
def test_flush_large_batch(self):
|
||||
for _ in range(MAX_BATCH_SIZE + 10):
|
||||
_ = self.telemetry._log(self.telemetry._session_end_event())
|
||||
with patch.object(
|
||||
self.telemetry, "_send_batch_with_retry", new_callable=AsyncMock
|
||||
) as mock_send:
|
||||
asyncio.run(self.telemetry._flush())
|
||||
assert len(self.telemetry._queue) == 0
|
||||
assert mock_send.call_count == 2
|
||||
|
||||
def test_counters_and_wait(self):
|
||||
initial_push = self.telemetry._push_counter
|
||||
for _ in range(3):
|
||||
_ = self.telemetry._log(self.telemetry._session_end_event())
|
||||
assert self.telemetry._push_counter == initial_push + 3
|
||||
with patch.object(
|
||||
self.telemetry,
|
||||
"_send_batch_with_retry",
|
||||
new_callable=AsyncMock,
|
||||
return_value=TelemetryResponse(status="accepted"),
|
||||
):
|
||||
asyncio.run(self.telemetry._flush())
|
||||
assert self.telemetry._flush_counter >= initial_push + 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_exception_sync_from_event_loop_protection(self):
|
||||
with patch("tinker.lib.telemetry._current_loop") as mock_current_loop:
|
||||
mock_current_loop.return_value = Mock()
|
||||
try:
|
||||
raise ValueError("Test error")
|
||||
except ValueError as e:
|
||||
result = self.telemetry.log_exception_sync(e)
|
||||
assert result is True
|
||||
|
||||
|
||||
class TestTelemetryProviderProtocol:
|
||||
def test_telemetry_provider_protocol(self):
|
||||
class ValidProvider:
|
||||
def get_telemetry(self):
|
||||
return None
|
||||
|
||||
class InvalidProvider:
|
||||
pass
|
||||
|
||||
assert isinstance(ValidProvider(), TelemetryProvider)
|
||||
assert not isinstance(InvalidProvider(), TelemetryProvider)
|
||||
|
||||
|
||||
class TestSyncContextManager:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_batch_uses_sync_context_manager(self):
|
||||
tinker_provider = MockAsyncTinkerProvider()
|
||||
telemetry = Telemetry(tinker_provider, session_id="test-session-id")
|
||||
events = [telemetry._session_start_event()]
|
||||
batch = telemetry._batch(cast(list[TelemetryEvent], events))
|
||||
result = await telemetry._send_batch(batch)
|
||||
tinker_provider._client_cm.__enter__.assert_called_once()
|
||||
tinker_provider._client_cm.__exit__.assert_called_once()
|
||||
tinker_provider.telemetry_send_mock.assert_called_once()
|
||||
assert result.status == "accepted"
|
||||
telemetry.stop()
|
||||
|
||||
|
||||
class TestCrossLoopSafety:
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_flush_from_different_loop(self):
|
||||
tinker_provider = MockAsyncTinkerProvider()
|
||||
telemetry = Telemetry(tinker_provider, session_id="test-session-id")
|
||||
tinker_provider.execute_callbacks()
|
||||
|
||||
def trigger_from_thread():
|
||||
telemetry._trigger_flush()
|
||||
|
||||
thread = threading.Thread(target=trigger_from_thread)
|
||||
thread.start()
|
||||
thread.join()
|
||||
tinker_provider.execute_callbacks()
|
||||
assert telemetry._flush_event.is_set()
|
||||
telemetry.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_periodic_flush_with_asyncio_event(self):
|
||||
tinker_provider = MockAsyncTinkerProvider()
|
||||
telemetry = Telemetry(tinker_provider, session_id="test-session-id")
|
||||
tinker_provider.execute_callbacks()
|
||||
telemetry._log(telemetry._session_end_event())
|
||||
with patch.object(telemetry, "_flush", new_callable=AsyncMock) as mock_flush:
|
||||
telemetry._flush_event.set()
|
||||
task = asyncio.create_task(telemetry._periodic_flush())
|
||||
await asyncio.sleep(0.1)
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
mock_flush.assert_called()
|
||||
telemetry.stop()
|
||||
0
src/tinker/py.typed
Normal file
0
src/tinker/py.typed
Normal file
103
src/tinker/resources/__init__.py
Normal file
103
src/tinker/resources/__init__.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from .models import (
|
||||
ModelsResource,
|
||||
AsyncModelsResource,
|
||||
ModelsResourceWithRawResponse,
|
||||
AsyncModelsResourceWithRawResponse,
|
||||
ModelsResourceWithStreamingResponse,
|
||||
AsyncModelsResourceWithStreamingResponse,
|
||||
)
|
||||
from .futures import (
|
||||
FuturesResource,
|
||||
AsyncFuturesResource,
|
||||
FuturesResourceWithRawResponse,
|
||||
AsyncFuturesResourceWithRawResponse,
|
||||
FuturesResourceWithStreamingResponse,
|
||||
AsyncFuturesResourceWithStreamingResponse,
|
||||
)
|
||||
from .service import (
|
||||
ServiceResource,
|
||||
AsyncServiceResource,
|
||||
ServiceResourceWithRawResponse,
|
||||
AsyncServiceResourceWithRawResponse,
|
||||
ServiceResourceWithStreamingResponse,
|
||||
AsyncServiceResourceWithStreamingResponse,
|
||||
)
|
||||
from .weights import (
|
||||
WeightsResource,
|
||||
AsyncWeightsResource,
|
||||
WeightsResourceWithRawResponse,
|
||||
AsyncWeightsResourceWithRawResponse,
|
||||
WeightsResourceWithStreamingResponse,
|
||||
AsyncWeightsResourceWithStreamingResponse,
|
||||
)
|
||||
from .sampling import (
|
||||
SamplingResource,
|
||||
AsyncSamplingResource,
|
||||
SamplingResourceWithRawResponse,
|
||||
AsyncSamplingResourceWithRawResponse,
|
||||
SamplingResourceWithStreamingResponse,
|
||||
AsyncSamplingResourceWithStreamingResponse,
|
||||
)
|
||||
from .training import (
|
||||
TrainingResource,
|
||||
AsyncTrainingResource,
|
||||
TrainingResourceWithRawResponse,
|
||||
AsyncTrainingResourceWithRawResponse,
|
||||
TrainingResourceWithStreamingResponse,
|
||||
AsyncTrainingResourceWithStreamingResponse,
|
||||
)
|
||||
from .telemetry import (
|
||||
TelemetryResource,
|
||||
AsyncTelemetryResource,
|
||||
TelemetryResourceWithRawResponse,
|
||||
AsyncTelemetryResourceWithRawResponse,
|
||||
TelemetryResourceWithStreamingResponse,
|
||||
AsyncTelemetryResourceWithStreamingResponse,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ServiceResource",
|
||||
"AsyncServiceResource",
|
||||
"ServiceResourceWithRawResponse",
|
||||
"AsyncServiceResourceWithRawResponse",
|
||||
"ServiceResourceWithStreamingResponse",
|
||||
"AsyncServiceResourceWithStreamingResponse",
|
||||
"TrainingResource",
|
||||
"AsyncTrainingResource",
|
||||
"TrainingResourceWithRawResponse",
|
||||
"AsyncTrainingResourceWithRawResponse",
|
||||
"TrainingResourceWithStreamingResponse",
|
||||
"AsyncTrainingResourceWithStreamingResponse",
|
||||
"ModelsResource",
|
||||
"AsyncModelsResource",
|
||||
"ModelsResourceWithRawResponse",
|
||||
"AsyncModelsResourceWithRawResponse",
|
||||
"ModelsResourceWithStreamingResponse",
|
||||
"AsyncModelsResourceWithStreamingResponse",
|
||||
"WeightsResource",
|
||||
"AsyncWeightsResource",
|
||||
"WeightsResourceWithRawResponse",
|
||||
"AsyncWeightsResourceWithRawResponse",
|
||||
"WeightsResourceWithStreamingResponse",
|
||||
"AsyncWeightsResourceWithStreamingResponse",
|
||||
"SamplingResource",
|
||||
"AsyncSamplingResource",
|
||||
"SamplingResourceWithRawResponse",
|
||||
"AsyncSamplingResourceWithRawResponse",
|
||||
"SamplingResourceWithStreamingResponse",
|
||||
"AsyncSamplingResourceWithStreamingResponse",
|
||||
"FuturesResource",
|
||||
"AsyncFuturesResource",
|
||||
"FuturesResourceWithRawResponse",
|
||||
"AsyncFuturesResourceWithRawResponse",
|
||||
"FuturesResourceWithStreamingResponse",
|
||||
"AsyncFuturesResourceWithStreamingResponse",
|
||||
"TelemetryResource",
|
||||
"AsyncTelemetryResource",
|
||||
"TelemetryResourceWithRawResponse",
|
||||
"AsyncTelemetryResourceWithRawResponse",
|
||||
"TelemetryResourceWithStreamingResponse",
|
||||
"AsyncTelemetryResourceWithStreamingResponse",
|
||||
]
|
||||
210
src/tinker/resources/futures.py
Normal file
210
src/tinker/resources/futures.py
Normal file
|
|
@ -0,0 +1,210 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from ..types import ModelID, RequestID, future_retrieve_params
|
||||
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
|
||||
from .._utils import maybe_transform, async_maybe_transform
|
||||
from .._compat import cached_property
|
||||
from .._resource import SyncAPIResource, AsyncAPIResource
|
||||
from .._response import (
|
||||
to_raw_response_wrapper,
|
||||
to_streamed_response_wrapper,
|
||||
async_to_raw_response_wrapper,
|
||||
async_to_streamed_response_wrapper,
|
||||
)
|
||||
from .._base_client import make_request_options
|
||||
from ..types.model_id import ModelID
|
||||
from ..types.request_id import RequestID
|
||||
from ..types.future_retrieve_response import FutureRetrieveResponse
|
||||
|
||||
__all__ = ["FuturesResource", "AsyncFuturesResource"]
|
||||
|
||||
|
||||
class FuturesResource(SyncAPIResource):
|
||||
@cached_property
|
||||
def with_raw_response(self) -> FuturesResourceWithRawResponse:
|
||||
"""
|
||||
This property can be used as a prefix for any HTTP method call to return
|
||||
the raw response object instead of the parsed content.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
|
||||
"""
|
||||
return FuturesResourceWithRawResponse(self)
|
||||
|
||||
@cached_property
|
||||
def with_streaming_response(self) -> FuturesResourceWithStreamingResponse:
|
||||
"""
|
||||
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
|
||||
"""
|
||||
return FuturesResourceWithStreamingResponse(self)
|
||||
|
||||
def retrieve(
|
||||
self,
|
||||
*,
|
||||
request_id: RequestID,
|
||||
model_id: ModelID | NotGiven = NOT_GIVEN,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
) -> FutureRetrieveResponse:
|
||||
"""
|
||||
Retrieves the result of a future by its ID
|
||||
|
||||
Args:
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
return cast(
|
||||
FutureRetrieveResponse,
|
||||
self._post(
|
||||
"/api/v1/retrieve_future",
|
||||
body=maybe_transform(
|
||||
{
|
||||
"request_id": request_id,
|
||||
"model_id": model_id,
|
||||
},
|
||||
future_retrieve_params.FutureRetrieveParams,
|
||||
),
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
),
|
||||
cast_to=cast(
|
||||
Any, FutureRetrieveResponse
|
||||
), # Union types cannot be passed in as arguments in the type system
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class AsyncFuturesResource(AsyncAPIResource):
|
||||
@cached_property
|
||||
def with_raw_response(self) -> AsyncFuturesResourceWithRawResponse:
|
||||
"""
|
||||
This property can be used as a prefix for any HTTP method call to return
|
||||
the raw response object instead of the parsed content.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
|
||||
"""
|
||||
return AsyncFuturesResourceWithRawResponse(self)
|
||||
|
||||
@cached_property
|
||||
def with_streaming_response(self) -> AsyncFuturesResourceWithStreamingResponse:
|
||||
"""
|
||||
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
|
||||
"""
|
||||
return AsyncFuturesResourceWithStreamingResponse(self)
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
*,
|
||||
request_id: RequestID,
|
||||
model_id: ModelID | NotGiven = NOT_GIVEN,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
max_retries: int | NotGiven = NOT_GIVEN,
|
||||
) -> FutureRetrieveResponse:
|
||||
"""
|
||||
Retrieves the result of a future by its ID
|
||||
|
||||
Args:
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
)
|
||||
if not isinstance(max_retries, NotGiven):
|
||||
options["max_retries"] = max_retries
|
||||
|
||||
return cast(
|
||||
FutureRetrieveResponse,
|
||||
await self._post(
|
||||
"/api/v1/retrieve_future",
|
||||
body=await async_maybe_transform(
|
||||
{
|
||||
"request_id": request_id,
|
||||
"model_id": model_id,
|
||||
},
|
||||
future_retrieve_params.FutureRetrieveParams,
|
||||
),
|
||||
options=options,
|
||||
cast_to=cast(
|
||||
Any, FutureRetrieveResponse
|
||||
), # Union types cannot be passed in as arguments in the type system
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class FuturesResourceWithRawResponse:
|
||||
def __init__(self, futures: FuturesResource) -> None:
|
||||
self._futures = futures
|
||||
|
||||
self.retrieve = to_raw_response_wrapper(
|
||||
futures.retrieve,
|
||||
)
|
||||
|
||||
|
||||
class AsyncFuturesResourceWithRawResponse:
|
||||
def __init__(self, futures: AsyncFuturesResource) -> None:
|
||||
self._futures = futures
|
||||
|
||||
self.retrieve = async_to_raw_response_wrapper(
|
||||
futures.retrieve,
|
||||
)
|
||||
|
||||
|
||||
class FuturesResourceWithStreamingResponse:
|
||||
def __init__(self, futures: FuturesResource) -> None:
|
||||
self._futures = futures
|
||||
|
||||
self.retrieve = to_streamed_response_wrapper(
|
||||
futures.retrieve,
|
||||
)
|
||||
|
||||
|
||||
class AsyncFuturesResourceWithStreamingResponse:
|
||||
def __init__(self, futures: AsyncFuturesResource) -> None:
|
||||
self._futures = futures
|
||||
|
||||
self.retrieve = async_to_streamed_response_wrapper(
|
||||
futures.retrieve,
|
||||
)
|
||||
412
src/tinker/resources/models.py
Normal file
412
src/tinker/resources/models.py
Normal file
|
|
@ -0,0 +1,412 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
import httpx
|
||||
|
||||
from ..types import ModelID, model_create_params, model_unload_params, model_get_info_params
|
||||
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
|
||||
from .._utils import maybe_transform, async_maybe_transform
|
||||
from .._compat import cached_property
|
||||
from .._resource import SyncAPIResource, AsyncAPIResource
|
||||
from .._response import (
|
||||
to_raw_response_wrapper,
|
||||
to_streamed_response_wrapper,
|
||||
async_to_raw_response_wrapper,
|
||||
async_to_streamed_response_wrapper,
|
||||
)
|
||||
from .._base_client import make_request_options
|
||||
from ..types.model_id import ModelID
|
||||
from ..types.get_info_response import GetInfoResponse
|
||||
from ..types.lora_config_param import LoraConfigParam
|
||||
from ..types.shared.untyped_api_future import UntypedAPIFuture
|
||||
|
||||
__all__ = ["ModelsResource", "AsyncModelsResource"]
|
||||
|
||||
|
||||
class ModelsResource(SyncAPIResource):
|
||||
@cached_property
|
||||
def with_raw_response(self) -> ModelsResourceWithRawResponse:
|
||||
"""
|
||||
This property can be used as a prefix for any HTTP method call to return
|
||||
the raw response object instead of the parsed content.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
|
||||
"""
|
||||
return ModelsResourceWithRawResponse(self)
|
||||
|
||||
@cached_property
|
||||
def with_streaming_response(self) -> ModelsResourceWithStreamingResponse:
|
||||
"""
|
||||
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
|
||||
"""
|
||||
return ModelsResourceWithStreamingResponse(self)
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
base_model: str,
|
||||
lora_config: LoraConfigParam | NotGiven = NOT_GIVEN,
|
||||
type: Literal["create_model"] = "create_model",
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
) -> UntypedAPIFuture:
|
||||
"""Creates a new model.
|
||||
|
||||
Pass a LoRA config to create a new LoRA adapter for the
|
||||
base model.
|
||||
|
||||
Args:
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
return self._post(
|
||||
"/api/v1/create_model",
|
||||
body=maybe_transform(
|
||||
{
|
||||
"base_model": base_model,
|
||||
"lora_config": lora_config,
|
||||
"type": type,
|
||||
},
|
||||
model_create_params.ModelCreateParams,
|
||||
),
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
),
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
||||
def get_info(
|
||||
self,
|
||||
*,
|
||||
model_id: ModelID,
|
||||
type: Literal["get_info"] | NotGiven = NOT_GIVEN,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
) -> GetInfoResponse:
|
||||
"""
|
||||
Retrieves information about the current model
|
||||
|
||||
Args:
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
return self._post(
|
||||
"/api/v1/get_info",
|
||||
body=maybe_transform(
|
||||
{
|
||||
"model_id": model_id,
|
||||
"type": type,
|
||||
},
|
||||
model_get_info_params.ModelGetInfoParams,
|
||||
),
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
),
|
||||
cast_to=GetInfoResponse,
|
||||
)
|
||||
|
||||
def unload(
|
||||
self,
|
||||
*,
|
||||
model_id: ModelID,
|
||||
type: Literal["unload_model"] | NotGiven = NOT_GIVEN,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
) -> UntypedAPIFuture:
|
||||
"""
|
||||
Unload the model weights and ends the user's session.
|
||||
|
||||
Args:
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
return self._post(
|
||||
"/api/v1/unload_model",
|
||||
body=maybe_transform(
|
||||
{
|
||||
"model_id": model_id,
|
||||
"type": type,
|
||||
},
|
||||
model_unload_params.ModelUnloadParams,
|
||||
),
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
),
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
||||
|
||||
class AsyncModelsResource(AsyncAPIResource):
|
||||
@cached_property
|
||||
def with_raw_response(self) -> AsyncModelsResourceWithRawResponse:
|
||||
"""
|
||||
This property can be used as a prefix for any HTTP method call to return
|
||||
the raw response object instead of the parsed content.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
|
||||
"""
|
||||
return AsyncModelsResourceWithRawResponse(self)
|
||||
|
||||
@cached_property
|
||||
def with_streaming_response(self) -> AsyncModelsResourceWithStreamingResponse:
|
||||
"""
|
||||
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
|
||||
"""
|
||||
return AsyncModelsResourceWithStreamingResponse(self)
|
||||
|
||||
async def create(
|
||||
self,
|
||||
*,
|
||||
base_model: str,
|
||||
lora_config: LoraConfigParam | NotGiven = NOT_GIVEN,
|
||||
type: Literal["create_model"] = "create_model",
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
) -> UntypedAPIFuture:
|
||||
"""Creates a new model.
|
||||
|
||||
Pass a LoRA config to create a new LoRA adapter for the
|
||||
base model.
|
||||
|
||||
Args:
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
return await self._post(
|
||||
"/api/v1/create_model",
|
||||
body=await async_maybe_transform(
|
||||
{
|
||||
"base_model": base_model,
|
||||
"lora_config": lora_config,
|
||||
"type": type,
|
||||
},
|
||||
model_create_params.ModelCreateParams,
|
||||
),
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
),
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
||||
async def get_info(
|
||||
self,
|
||||
*,
|
||||
model_id: ModelID,
|
||||
type: Literal["get_info"] | NotGiven = NOT_GIVEN,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
) -> GetInfoResponse:
|
||||
"""
|
||||
Retrieves information about the current model
|
||||
|
||||
Args:
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
return await self._post(
|
||||
"/api/v1/get_info",
|
||||
body=await async_maybe_transform(
|
||||
{
|
||||
"model_id": model_id,
|
||||
"type": type,
|
||||
},
|
||||
model_get_info_params.ModelGetInfoParams,
|
||||
),
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
),
|
||||
cast_to=GetInfoResponse,
|
||||
)
|
||||
|
||||
async def unload(
|
||||
self,
|
||||
*,
|
||||
model_id: ModelID,
|
||||
type: Literal["unload_model"] | NotGiven = NOT_GIVEN,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
) -> UntypedAPIFuture:
|
||||
"""
|
||||
Unload the model weights and ends the user's session.
|
||||
|
||||
Args:
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
return await self._post(
|
||||
"/api/v1/unload_model",
|
||||
body=await async_maybe_transform(
|
||||
{
|
||||
"model_id": model_id,
|
||||
"type": type,
|
||||
},
|
||||
model_unload_params.ModelUnloadParams,
|
||||
),
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
),
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
||||
|
||||
class ModelsResourceWithRawResponse:
|
||||
def __init__(self, models: ModelsResource) -> None:
|
||||
self._models = models
|
||||
|
||||
self.create = to_raw_response_wrapper(
|
||||
models.create,
|
||||
)
|
||||
self.get_info = to_raw_response_wrapper(
|
||||
models.get_info,
|
||||
)
|
||||
self.unload = to_raw_response_wrapper(
|
||||
models.unload,
|
||||
)
|
||||
|
||||
|
||||
class AsyncModelsResourceWithRawResponse:
|
||||
def __init__(self, models: AsyncModelsResource) -> None:
|
||||
self._models = models
|
||||
|
||||
self.create = async_to_raw_response_wrapper(
|
||||
models.create,
|
||||
)
|
||||
self.get_info = async_to_raw_response_wrapper(
|
||||
models.get_info,
|
||||
)
|
||||
self.unload = async_to_raw_response_wrapper(
|
||||
models.unload,
|
||||
)
|
||||
|
||||
|
||||
class ModelsResourceWithStreamingResponse:
|
||||
def __init__(self, models: ModelsResource) -> None:
|
||||
self._models = models
|
||||
|
||||
self.create = to_streamed_response_wrapper(
|
||||
models.create,
|
||||
)
|
||||
self.get_info = to_streamed_response_wrapper(
|
||||
models.get_info,
|
||||
)
|
||||
self.unload = to_streamed_response_wrapper(
|
||||
models.unload,
|
||||
)
|
||||
|
||||
|
||||
class AsyncModelsResourceWithStreamingResponse:
|
||||
def __init__(self, models: AsyncModelsResource) -> None:
|
||||
self._models = models
|
||||
|
||||
self.create = async_to_streamed_response_wrapper(
|
||||
models.create,
|
||||
)
|
||||
self.get_info = async_to_streamed_response_wrapper(
|
||||
models.get_info,
|
||||
)
|
||||
self.unload = async_to_streamed_response_wrapper(
|
||||
models.unload,
|
||||
)
|
||||
394
src/tinker/resources/sampling.py
Normal file
394
src/tinker/resources/sampling.py
Normal file
|
|
@ -0,0 +1,394 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
import httpx
|
||||
|
||||
from ..types import sampling_sample_params, sampling_asample_params
|
||||
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
|
||||
from .._utils import maybe_transform, async_maybe_transform
|
||||
from .._compat import cached_property
|
||||
from .._resource import SyncAPIResource, AsyncAPIResource
|
||||
from .._response import (
|
||||
to_raw_response_wrapper,
|
||||
to_streamed_response_wrapper,
|
||||
async_to_raw_response_wrapper,
|
||||
async_to_streamed_response_wrapper,
|
||||
)
|
||||
from .._base_client import make_request_options
|
||||
from ..types.sample_response import SampleResponse
|
||||
from ..types.model_input_param import ModelInputParam
|
||||
from ..types.sampling_params_param import SamplingParamsParam
|
||||
from ..types.shared.untyped_api_future import UntypedAPIFuture
|
||||
|
||||
__all__ = ["SamplingResource", "AsyncSamplingResource"]
|
||||
|
||||
|
||||
class SamplingResource(SyncAPIResource):
|
||||
@cached_property
|
||||
def with_raw_response(self) -> SamplingResourceWithRawResponse:
|
||||
"""
|
||||
This property can be used as a prefix for any HTTP method call to return
|
||||
the raw response object instead of the parsed content.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
|
||||
"""
|
||||
return SamplingResourceWithRawResponse(self)
|
||||
|
||||
@cached_property
|
||||
def with_streaming_response(self) -> SamplingResourceWithStreamingResponse:
|
||||
"""
|
||||
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
|
||||
"""
|
||||
return SamplingResourceWithStreamingResponse(self)
|
||||
|
||||
def asample(
|
||||
self,
|
||||
*,
|
||||
num_samples: int = 1,
|
||||
prompt: ModelInputParam,
|
||||
sampling_params: SamplingParamsParam,
|
||||
base_model: str | NotGiven = NOT_GIVEN,
|
||||
model_path: str | NotGiven = NOT_GIVEN,
|
||||
prompt_logprobs: bool | NotGiven = NOT_GIVEN,
|
||||
type: Literal["sample"] | NotGiven = NOT_GIVEN,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
) -> UntypedAPIFuture:
|
||||
"""
|
||||
Generates samples from the model using the specified sampling parameters
|
||||
|
||||
Args:
|
||||
num_samples: Number of samples to generate
|
||||
|
||||
base_model: Optional base model name to sample from. Is inferred from model_path, if
|
||||
provided. If sampling against a base model, this is required.
|
||||
|
||||
model_path: Optional tinker:// path to your model weights or LoRA weights. If not provided,
|
||||
samples against the base model.
|
||||
|
||||
prompt_logprobs: If set to `true`, computes and returns logprobs on the prompt tokens. Defaults
|
||||
to false.
|
||||
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
return self._post(
|
||||
"/api/v1/asample",
|
||||
body=maybe_transform(
|
||||
{
|
||||
"num_samples": num_samples,
|
||||
"prompt": prompt,
|
||||
"sampling_params": sampling_params,
|
||||
"base_model": base_model,
|
||||
"model_path": model_path,
|
||||
"prompt_logprobs": prompt_logprobs,
|
||||
"type": type,
|
||||
},
|
||||
sampling_asample_params.SamplingAsampleParams,
|
||||
),
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
),
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
*,
|
||||
num_samples: int = 1,
|
||||
prompt: ModelInputParam,
|
||||
sampling_params: SamplingParamsParam,
|
||||
base_model: str | NotGiven = NOT_GIVEN,
|
||||
model_path: str | NotGiven = NOT_GIVEN,
|
||||
prompt_logprobs: bool | NotGiven = NOT_GIVEN,
|
||||
type: Literal["sample"] | NotGiven = NOT_GIVEN,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
) -> SampleResponse:
|
||||
"""
|
||||
Generates samples from the model using the specified sampling parameters
|
||||
|
||||
Args:
|
||||
num_samples: Number of samples to generate
|
||||
|
||||
base_model: Optional base model name to sample from. Is inferred from model_path, if
|
||||
provided. If sampling against a base model, this is required.
|
||||
|
||||
model_path: Optional tinker:// path to your model weights or LoRA weights. If not provided,
|
||||
samples against the base model.
|
||||
|
||||
prompt_logprobs: If set to `true`, computes and returns logprobs on the prompt tokens. Defaults
|
||||
to false.
|
||||
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
return self._post(
|
||||
"/api/v1/sample",
|
||||
body=maybe_transform(
|
||||
{
|
||||
"num_samples": num_samples,
|
||||
"prompt": prompt,
|
||||
"sampling_params": sampling_params,
|
||||
"base_model": base_model,
|
||||
"model_path": model_path,
|
||||
"prompt_logprobs": prompt_logprobs,
|
||||
"type": type,
|
||||
},
|
||||
sampling_sample_params.SamplingSampleParams,
|
||||
),
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
),
|
||||
cast_to=SampleResponse,
|
||||
)
|
||||
|
||||
|
||||
class AsyncSamplingResource(AsyncAPIResource):
|
||||
@cached_property
|
||||
def with_raw_response(self) -> AsyncSamplingResourceWithRawResponse:
|
||||
"""
|
||||
This property can be used as a prefix for any HTTP method call to return
|
||||
the raw response object instead of the parsed content.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
|
||||
"""
|
||||
return AsyncSamplingResourceWithRawResponse(self)
|
||||
|
||||
@cached_property
|
||||
def with_streaming_response(self) -> AsyncSamplingResourceWithStreamingResponse:
|
||||
"""
|
||||
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
|
||||
"""
|
||||
return AsyncSamplingResourceWithStreamingResponse(self)
|
||||
|
||||
async def asample(
|
||||
self,
|
||||
*,
|
||||
num_samples: int = 1,
|
||||
prompt: ModelInputParam,
|
||||
sampling_params: SamplingParamsParam,
|
||||
base_model: str | NotGiven = NOT_GIVEN,
|
||||
model_path: str | NotGiven = NOT_GIVEN,
|
||||
prompt_logprobs: bool | NotGiven = NOT_GIVEN,
|
||||
type: Literal["sample"] | NotGiven = NOT_GIVEN,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
max_retries: int | NotGiven = NOT_GIVEN,
|
||||
) -> UntypedAPIFuture:
|
||||
"""
|
||||
Generates samples from the model using the specified sampling parameters
|
||||
|
||||
Args:
|
||||
num_samples: Number of samples to generate
|
||||
|
||||
base_model: Optional base model name to sample from. Is inferred from model_path, if
|
||||
provided. If sampling against a base model, this is required.
|
||||
|
||||
model_path: Optional tinker:// path to your model weights or LoRA weights. If not provided,
|
||||
samples against the base model.
|
||||
|
||||
prompt_logprobs: If set to `true`, computes and returns logprobs on the prompt tokens. Defaults
|
||||
to false.
|
||||
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
options = make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
)
|
||||
if max_retries is not NOT_GIVEN:
|
||||
options["max_retries"] = max_retries
|
||||
|
||||
return await self._post(
|
||||
"/api/v1/asample",
|
||||
body=await async_maybe_transform(
|
||||
{
|
||||
"num_samples": num_samples,
|
||||
"prompt": prompt,
|
||||
"sampling_params": sampling_params,
|
||||
"base_model": base_model,
|
||||
"model_path": model_path,
|
||||
"prompt_logprobs": prompt_logprobs,
|
||||
"type": type,
|
||||
},
|
||||
sampling_asample_params.SamplingAsampleParams,
|
||||
),
|
||||
options=options,
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
||||
async def sample(
|
||||
self,
|
||||
*,
|
||||
num_samples: int = 1,
|
||||
prompt: ModelInputParam,
|
||||
sampling_params: SamplingParamsParam,
|
||||
base_model: str | NotGiven = NOT_GIVEN,
|
||||
model_path: str | NotGiven = NOT_GIVEN,
|
||||
prompt_logprobs: bool | NotGiven = NOT_GIVEN,
|
||||
type: Literal["sample"] | NotGiven = NOT_GIVEN,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
max_retries: int | NotGiven = NOT_GIVEN,
|
||||
) -> SampleResponse:
|
||||
"""
|
||||
Generates samples from the model using the specified sampling parameters
|
||||
|
||||
Args:
|
||||
num_samples: Number of samples to generate
|
||||
|
||||
base_model: Optional base model name to sample from. Is inferred from model_path, if
|
||||
provided. If sampling against a base model, this is required.
|
||||
|
||||
model_path: Optional tinker:// path to your model weights or LoRA weights. If not provided,
|
||||
samples against the base model.
|
||||
|
||||
prompt_logprobs: If set to `true`, computes and returns logprobs on the prompt tokens. Defaults
|
||||
to false.
|
||||
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
options = make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
)
|
||||
if max_retries is not NOT_GIVEN:
|
||||
options["max_retries"] = max_retries
|
||||
|
||||
return await self._post(
|
||||
"/api/v1/sample",
|
||||
body=await async_maybe_transform(
|
||||
{
|
||||
"num_samples": num_samples,
|
||||
"prompt": prompt,
|
||||
"sampling_params": sampling_params,
|
||||
"base_model": base_model,
|
||||
"model_path": model_path,
|
||||
"prompt_logprobs": prompt_logprobs,
|
||||
"type": type,
|
||||
},
|
||||
sampling_sample_params.SamplingSampleParams,
|
||||
),
|
||||
options=options,
|
||||
cast_to=SampleResponse,
|
||||
)
|
||||
|
||||
|
||||
class SamplingResourceWithRawResponse:
|
||||
def __init__(self, sampling: SamplingResource) -> None:
|
||||
self._sampling = sampling
|
||||
|
||||
self.asample = to_raw_response_wrapper(
|
||||
sampling.asample,
|
||||
)
|
||||
self.sample = to_raw_response_wrapper(
|
||||
sampling.sample,
|
||||
)
|
||||
|
||||
|
||||
class AsyncSamplingResourceWithRawResponse:
|
||||
def __init__(self, sampling: AsyncSamplingResource) -> None:
|
||||
self._sampling = sampling
|
||||
|
||||
self.asample = async_to_raw_response_wrapper(
|
||||
sampling.asample,
|
||||
)
|
||||
self.sample = async_to_raw_response_wrapper(
|
||||
sampling.sample,
|
||||
)
|
||||
|
||||
|
||||
class SamplingResourceWithStreamingResponse:
|
||||
def __init__(self, sampling: SamplingResource) -> None:
|
||||
self._sampling = sampling
|
||||
|
||||
self.asample = to_streamed_response_wrapper(
|
||||
sampling.asample,
|
||||
)
|
||||
self.sample = to_streamed_response_wrapper(
|
||||
sampling.sample,
|
||||
)
|
||||
|
||||
|
||||
class AsyncSamplingResourceWithStreamingResponse:
|
||||
def __init__(self, sampling: AsyncSamplingResource) -> None:
|
||||
self._sampling = sampling
|
||||
|
||||
self.asample = async_to_streamed_response_wrapper(
|
||||
sampling.asample,
|
||||
)
|
||||
self.sample = async_to_streamed_response_wrapper(
|
||||
sampling.sample,
|
||||
)
|
||||
186
src/tinker/resources/service.py
Normal file
186
src/tinker/resources/service.py
Normal file
|
|
@ -0,0 +1,186 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
|
||||
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
|
||||
from .._compat import cached_property
|
||||
from .._resource import SyncAPIResource, AsyncAPIResource
|
||||
from .._response import (
|
||||
to_raw_response_wrapper,
|
||||
to_streamed_response_wrapper,
|
||||
async_to_raw_response_wrapper,
|
||||
async_to_streamed_response_wrapper,
|
||||
)
|
||||
from .._base_client import make_request_options
|
||||
from ..types.health_response import HealthResponse
|
||||
from ..types.get_server_capabilities_response import GetServerCapabilitiesResponse
|
||||
|
||||
__all__ = ["ServiceResource", "AsyncServiceResource"]
|
||||
|
||||
|
||||
class ServiceResource(SyncAPIResource):
|
||||
@cached_property
|
||||
def with_raw_response(self) -> ServiceResourceWithRawResponse:
|
||||
"""
|
||||
This property can be used as a prefix for any HTTP method call to return
|
||||
the raw response object instead of the parsed content.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
|
||||
"""
|
||||
return ServiceResourceWithRawResponse(self)
|
||||
|
||||
@cached_property
|
||||
def with_streaming_response(self) -> ServiceResourceWithStreamingResponse:
|
||||
"""
|
||||
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
|
||||
"""
|
||||
return ServiceResourceWithStreamingResponse(self)
|
||||
|
||||
def get_server_capabilities(
|
||||
self,
|
||||
*,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> GetServerCapabilitiesResponse:
|
||||
"""Retrieves information about supported models and server capabilities"""
|
||||
return self._get(
|
||||
"/api/v1/get_server_capabilities",
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
|
||||
),
|
||||
cast_to=GetServerCapabilitiesResponse,
|
||||
)
|
||||
|
||||
def health_check(
|
||||
self,
|
||||
*,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> HealthResponse:
|
||||
"""Checks if the API server is ready"""
|
||||
return self._get(
|
||||
"/api/v1/healthz",
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
|
||||
),
|
||||
cast_to=HealthResponse,
|
||||
)
|
||||
|
||||
|
||||
class AsyncServiceResource(AsyncAPIResource):
|
||||
@cached_property
|
||||
def with_raw_response(self) -> AsyncServiceResourceWithRawResponse:
|
||||
"""
|
||||
This property can be used as a prefix for any HTTP method call to return
|
||||
the raw response object instead of the parsed content.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
|
||||
"""
|
||||
return AsyncServiceResourceWithRawResponse(self)
|
||||
|
||||
@cached_property
|
||||
def with_streaming_response(self) -> AsyncServiceResourceWithStreamingResponse:
|
||||
"""
|
||||
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
|
||||
"""
|
||||
return AsyncServiceResourceWithStreamingResponse(self)
|
||||
|
||||
async def get_server_capabilities(
|
||||
self,
|
||||
*,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> GetServerCapabilitiesResponse:
|
||||
"""Retrieves information about supported models and server capabilities"""
|
||||
return await self._get(
|
||||
"/api/v1/get_server_capabilities",
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
|
||||
),
|
||||
cast_to=GetServerCapabilitiesResponse,
|
||||
)
|
||||
|
||||
async def health_check(
|
||||
self,
|
||||
*,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> HealthResponse:
|
||||
"""Checks if the API server is ready"""
|
||||
return await self._get(
|
||||
"/api/v1/healthz",
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
|
||||
),
|
||||
cast_to=HealthResponse,
|
||||
)
|
||||
|
||||
|
||||
class ServiceResourceWithRawResponse:
|
||||
def __init__(self, service: ServiceResource) -> None:
|
||||
self._service = service
|
||||
|
||||
self.get_server_capabilities = to_raw_response_wrapper(
|
||||
service.get_server_capabilities,
|
||||
)
|
||||
self.health_check = to_raw_response_wrapper(
|
||||
service.health_check,
|
||||
)
|
||||
|
||||
|
||||
class AsyncServiceResourceWithRawResponse:
|
||||
def __init__(self, service: AsyncServiceResource) -> None:
|
||||
self._service = service
|
||||
|
||||
self.get_server_capabilities = async_to_raw_response_wrapper(
|
||||
service.get_server_capabilities,
|
||||
)
|
||||
self.health_check = async_to_raw_response_wrapper(
|
||||
service.health_check,
|
||||
)
|
||||
|
||||
|
||||
class ServiceResourceWithStreamingResponse:
|
||||
def __init__(self, service: ServiceResource) -> None:
|
||||
self._service = service
|
||||
|
||||
self.get_server_capabilities = to_streamed_response_wrapper(
|
||||
service.get_server_capabilities,
|
||||
)
|
||||
self.health_check = to_streamed_response_wrapper(
|
||||
service.health_check,
|
||||
)
|
||||
|
||||
|
||||
class AsyncServiceResourceWithStreamingResponse:
|
||||
def __init__(self, service: AsyncServiceResource) -> None:
|
||||
self._service = service
|
||||
|
||||
self.get_server_capabilities = async_to_streamed_response_wrapper(
|
||||
service.get_server_capabilities,
|
||||
)
|
||||
self.health_check = async_to_streamed_response_wrapper(
|
||||
service.health_check,
|
||||
)
|
||||
210
src/tinker/resources/telemetry.py
Normal file
210
src/tinker/resources/telemetry.py
Normal file
|
|
@ -0,0 +1,210 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
import httpx
|
||||
|
||||
from ..types import telemetry_send_params
|
||||
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
|
||||
from .._utils import maybe_transform, async_maybe_transform
|
||||
from .._compat import cached_property
|
||||
from .._resource import SyncAPIResource, AsyncAPIResource
|
||||
from .._response import (
|
||||
to_raw_response_wrapper,
|
||||
to_streamed_response_wrapper,
|
||||
async_to_raw_response_wrapper,
|
||||
async_to_streamed_response_wrapper,
|
||||
)
|
||||
from .._base_client import make_request_options
|
||||
from ..types.telemetry_response import TelemetryResponse
|
||||
from ..types.telemetry_event_param import TelemetryEventParam
|
||||
|
||||
__all__ = ["TelemetryResource", "AsyncTelemetryResource"]
|
||||
|
||||
|
||||
class TelemetryResource(SyncAPIResource):
|
||||
@cached_property
|
||||
def with_raw_response(self) -> TelemetryResourceWithRawResponse:
|
||||
"""
|
||||
This property can be used as a prefix for any HTTP method call to return
|
||||
the raw response object instead of the parsed content.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
|
||||
"""
|
||||
return TelemetryResourceWithRawResponse(self)
|
||||
|
||||
@cached_property
|
||||
def with_streaming_response(self) -> TelemetryResourceWithStreamingResponse:
|
||||
"""
|
||||
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
|
||||
"""
|
||||
return TelemetryResourceWithStreamingResponse(self)
|
||||
|
||||
def send(
|
||||
self,
|
||||
*,
|
||||
events: Iterable[TelemetryEventParam],
|
||||
platform: str,
|
||||
sdk_version: str,
|
||||
session_id: str,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
) -> TelemetryResponse:
|
||||
"""
|
||||
Accepts batches of SDK telemetry events for analytics and diagnostics
|
||||
|
||||
Args:
|
||||
platform: Host platform name
|
||||
|
||||
sdk_version: SDK version string
|
||||
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
return self._post(
|
||||
"/api/v1/telemetry",
|
||||
body=maybe_transform(
|
||||
{
|
||||
"events": events,
|
||||
"platform": platform,
|
||||
"sdk_version": sdk_version,
|
||||
"session_id": session_id,
|
||||
},
|
||||
telemetry_send_params.TelemetrySendParams,
|
||||
),
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
),
|
||||
cast_to=TelemetryResponse,
|
||||
)
|
||||
|
||||
|
||||
class AsyncTelemetryResource(AsyncAPIResource):
|
||||
@cached_property
|
||||
def with_raw_response(self) -> AsyncTelemetryResourceWithRawResponse:
|
||||
"""
|
||||
This property can be used as a prefix for any HTTP method call to return
|
||||
the raw response object instead of the parsed content.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
|
||||
"""
|
||||
return AsyncTelemetryResourceWithRawResponse(self)
|
||||
|
||||
@cached_property
|
||||
def with_streaming_response(self) -> AsyncTelemetryResourceWithStreamingResponse:
|
||||
"""
|
||||
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
|
||||
"""
|
||||
return AsyncTelemetryResourceWithStreamingResponse(self)
|
||||
|
||||
async def send(
|
||||
self,
|
||||
*,
|
||||
events: Iterable[TelemetryEventParam],
|
||||
platform: str,
|
||||
sdk_version: str,
|
||||
session_id: str,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
) -> TelemetryResponse:
|
||||
"""
|
||||
Accepts batches of SDK telemetry events for analytics and diagnostics
|
||||
|
||||
Args:
|
||||
platform: Host platform name
|
||||
|
||||
sdk_version: SDK version string
|
||||
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
return await self._post(
|
||||
"/api/v1/telemetry",
|
||||
body=await async_maybe_transform(
|
||||
{
|
||||
"events": events,
|
||||
"platform": platform,
|
||||
"sdk_version": sdk_version,
|
||||
"session_id": session_id,
|
||||
},
|
||||
telemetry_send_params.TelemetrySendParams,
|
||||
),
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
),
|
||||
cast_to=TelemetryResponse,
|
||||
)
|
||||
|
||||
|
||||
class TelemetryResourceWithRawResponse:
|
||||
def __init__(self, telemetry: TelemetryResource) -> None:
|
||||
self._telemetry = telemetry
|
||||
|
||||
self.send = to_raw_response_wrapper(
|
||||
telemetry.send,
|
||||
)
|
||||
|
||||
|
||||
class AsyncTelemetryResourceWithRawResponse:
|
||||
def __init__(self, telemetry: AsyncTelemetryResource) -> None:
|
||||
self._telemetry = telemetry
|
||||
|
||||
self.send = async_to_raw_response_wrapper(
|
||||
telemetry.send,
|
||||
)
|
||||
|
||||
|
||||
class TelemetryResourceWithStreamingResponse:
|
||||
def __init__(self, telemetry: TelemetryResource) -> None:
|
||||
self._telemetry = telemetry
|
||||
|
||||
self.send = to_streamed_response_wrapper(
|
||||
telemetry.send,
|
||||
)
|
||||
|
||||
|
||||
class AsyncTelemetryResourceWithStreamingResponse:
|
||||
def __init__(self, telemetry: AsyncTelemetryResource) -> None:
|
||||
self._telemetry = telemetry
|
||||
|
||||
self.send = async_to_streamed_response_wrapper(
|
||||
telemetry.send,
|
||||
)
|
||||
412
src/tinker/resources/training.py
Normal file
412
src/tinker/resources/training.py
Normal file
|
|
@ -0,0 +1,412 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
import httpx
|
||||
|
||||
from ..types import (
|
||||
ModelID,
|
||||
training_forward_params,
|
||||
training_optim_step_params,
|
||||
training_forward_backward_params,
|
||||
)
|
||||
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
|
||||
from .._utils import maybe_transform, async_maybe_transform
|
||||
from .._compat import cached_property
|
||||
from .._resource import SyncAPIResource, AsyncAPIResource
|
||||
from .._response import (
|
||||
to_raw_response_wrapper,
|
||||
to_streamed_response_wrapper,
|
||||
async_to_raw_response_wrapper,
|
||||
async_to_streamed_response_wrapper,
|
||||
)
|
||||
from .._base_client import make_request_options
|
||||
from ..types.model_id import ModelID
|
||||
from ..types.shared.untyped_api_future import UntypedAPIFuture
|
||||
from ..types.forward_backward_input_param import ForwardBackwardInputParam
|
||||
|
||||
__all__ = ["TrainingResource", "AsyncTrainingResource"]
|
||||
|
||||
|
||||
class TrainingResource(SyncAPIResource):
|
||||
@cached_property
|
||||
def with_raw_response(self) -> TrainingResourceWithRawResponse:
|
||||
"""
|
||||
This property can be used as a prefix for any HTTP method call to return
|
||||
the raw response object instead of the parsed content.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
|
||||
"""
|
||||
return TrainingResourceWithRawResponse(self)
|
||||
|
||||
@cached_property
|
||||
def with_streaming_response(self) -> TrainingResourceWithStreamingResponse:
|
||||
"""
|
||||
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
|
||||
"""
|
||||
return TrainingResourceWithStreamingResponse(self)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
*,
|
||||
forward_input: ForwardBackwardInputParam,
|
||||
model_id: ModelID,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
) -> UntypedAPIFuture:
|
||||
"""
|
||||
Performs a forward pass through the model
|
||||
|
||||
Args:
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
return self._post(
|
||||
"/api/v1/forward",
|
||||
body=maybe_transform(
|
||||
{
|
||||
"forward_input": forward_input,
|
||||
"model_id": model_id,
|
||||
},
|
||||
training_forward_params.TrainingForwardParams,
|
||||
),
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
),
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
||||
def forward_backward(
|
||||
self,
|
||||
*,
|
||||
forward_backward_input: ForwardBackwardInputParam,
|
||||
model_id: ModelID,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
) -> UntypedAPIFuture:
|
||||
"""
|
||||
Performs a forward and backward pass through the model
|
||||
|
||||
Args:
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
return self._post(
|
||||
"/api/v1/forward_backward",
|
||||
body=maybe_transform(
|
||||
{
|
||||
"forward_backward_input": forward_backward_input,
|
||||
"model_id": model_id,
|
||||
},
|
||||
training_forward_backward_params.TrainingForwardBackwardParams,
|
||||
),
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
),
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
||||
def optim_step(
|
||||
self,
|
||||
*,
|
||||
adam_params: training_optim_step_params.AdamParams,
|
||||
model_id: ModelID,
|
||||
type: Literal["optim_step"] | NotGiven = NOT_GIVEN,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
) -> UntypedAPIFuture:
|
||||
"""
|
||||
Performs an optimization step using AdamW optimizer
|
||||
|
||||
Args:
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
return self._post(
|
||||
"/api/v1/optim_step",
|
||||
body=maybe_transform(
|
||||
{
|
||||
"adam_params": adam_params,
|
||||
"model_id": model_id,
|
||||
"type": type,
|
||||
},
|
||||
training_optim_step_params.TrainingOptimStepParams,
|
||||
),
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
),
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
||||
|
||||
class AsyncTrainingResource(AsyncAPIResource):
|
||||
@cached_property
|
||||
def with_raw_response(self) -> AsyncTrainingResourceWithRawResponse:
|
||||
"""
|
||||
This property can be used as a prefix for any HTTP method call to return
|
||||
the raw response object instead of the parsed content.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
|
||||
"""
|
||||
return AsyncTrainingResourceWithRawResponse(self)
|
||||
|
||||
@cached_property
|
||||
def with_streaming_response(self) -> AsyncTrainingResourceWithStreamingResponse:
|
||||
"""
|
||||
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
|
||||
"""
|
||||
return AsyncTrainingResourceWithStreamingResponse(self)
|
||||
|
||||
async def forward(
|
||||
self,
|
||||
*,
|
||||
forward_input: ForwardBackwardInputParam,
|
||||
model_id: ModelID,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
) -> UntypedAPIFuture:
|
||||
"""
|
||||
Performs a forward pass through the model
|
||||
|
||||
Args:
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
return await self._post(
|
||||
"/api/v1/forward",
|
||||
body=await async_maybe_transform(
|
||||
{
|
||||
"forward_input": forward_input,
|
||||
"model_id": model_id,
|
||||
},
|
||||
training_forward_params.TrainingForwardParams,
|
||||
),
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
),
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
||||
async def forward_backward(
|
||||
self,
|
||||
*,
|
||||
forward_backward_input: ForwardBackwardInputParam,
|
||||
model_id: ModelID,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
) -> UntypedAPIFuture:
|
||||
"""
|
||||
Performs a forward and backward pass through the model
|
||||
|
||||
Args:
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
return await self._post(
|
||||
"/api/v1/forward_backward",
|
||||
body=await async_maybe_transform(
|
||||
{
|
||||
"forward_backward_input": forward_backward_input,
|
||||
"model_id": model_id,
|
||||
},
|
||||
training_forward_backward_params.TrainingForwardBackwardParams,
|
||||
),
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
),
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
||||
async def optim_step(
|
||||
self,
|
||||
*,
|
||||
adam_params: training_optim_step_params.AdamParams,
|
||||
model_id: ModelID,
|
||||
type: Literal["optim_step"] | NotGiven = NOT_GIVEN,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
) -> UntypedAPIFuture:
|
||||
"""
|
||||
Performs an optimization step using AdamW optimizer
|
||||
|
||||
Args:
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
return await self._post(
|
||||
"/api/v1/optim_step",
|
||||
body=await async_maybe_transform(
|
||||
{
|
||||
"adam_params": adam_params,
|
||||
"model_id": model_id,
|
||||
"type": type,
|
||||
},
|
||||
training_optim_step_params.TrainingOptimStepParams,
|
||||
),
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
),
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
||||
|
||||
class TrainingResourceWithRawResponse:
|
||||
def __init__(self, training: TrainingResource) -> None:
|
||||
self._training = training
|
||||
|
||||
self.forward = to_raw_response_wrapper(
|
||||
training.forward,
|
||||
)
|
||||
self.forward_backward = to_raw_response_wrapper(
|
||||
training.forward_backward,
|
||||
)
|
||||
self.optim_step = to_raw_response_wrapper(
|
||||
training.optim_step,
|
||||
)
|
||||
|
||||
|
||||
class AsyncTrainingResourceWithRawResponse:
|
||||
def __init__(self, training: AsyncTrainingResource) -> None:
|
||||
self._training = training
|
||||
|
||||
self.forward = async_to_raw_response_wrapper(
|
||||
training.forward,
|
||||
)
|
||||
self.forward_backward = async_to_raw_response_wrapper(
|
||||
training.forward_backward,
|
||||
)
|
||||
self.optim_step = async_to_raw_response_wrapper(
|
||||
training.optim_step,
|
||||
)
|
||||
|
||||
|
||||
class TrainingResourceWithStreamingResponse:
|
||||
def __init__(self, training: TrainingResource) -> None:
|
||||
self._training = training
|
||||
|
||||
self.forward = to_streamed_response_wrapper(
|
||||
training.forward,
|
||||
)
|
||||
self.forward_backward = to_streamed_response_wrapper(
|
||||
training.forward_backward,
|
||||
)
|
||||
self.optim_step = to_streamed_response_wrapper(
|
||||
training.optim_step,
|
||||
)
|
||||
|
||||
|
||||
class AsyncTrainingResourceWithStreamingResponse:
|
||||
def __init__(self, training: AsyncTrainingResource) -> None:
|
||||
self._training = training
|
||||
|
||||
self.forward = async_to_streamed_response_wrapper(
|
||||
training.forward,
|
||||
)
|
||||
self.forward_backward = async_to_streamed_response_wrapper(
|
||||
training.forward_backward,
|
||||
)
|
||||
self.optim_step = async_to_streamed_response_wrapper(
|
||||
training.optim_step,
|
||||
)
|
||||
596
src/tinker/resources/weights.py
Normal file
596
src/tinker/resources/weights.py
Normal file
|
|
@ -0,0 +1,596 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
import httpx
|
||||
|
||||
from ..types import (
|
||||
ModelID,
|
||||
CheckpointsListResponse,
|
||||
weight_load_params,
|
||||
weight_save_params,
|
||||
weight_save_for_sampler_params,
|
||||
)
|
||||
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven, NoneType
|
||||
from .._utils import maybe_transform, async_maybe_transform
|
||||
from .._compat import cached_property
|
||||
from .._resource import SyncAPIResource, AsyncAPIResource
|
||||
from .._response import (
|
||||
to_raw_response_wrapper,
|
||||
to_streamed_response_wrapper,
|
||||
async_to_raw_response_wrapper,
|
||||
async_to_streamed_response_wrapper,
|
||||
)
|
||||
from .._base_client import make_request_options
|
||||
from ..types.shared.untyped_api_future import UntypedAPIFuture
|
||||
|
||||
__all__ = ["WeightsResource", "AsyncWeightsResource"]
|
||||
|
||||
|
||||
class WeightsResource(SyncAPIResource):
|
||||
@cached_property
|
||||
def with_raw_response(self) -> WeightsResourceWithRawResponse:
|
||||
"""
|
||||
This property can be used as a prefix for any HTTP method call to return
|
||||
the raw response object instead of the parsed content.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
|
||||
"""
|
||||
return WeightsResourceWithRawResponse(self)
|
||||
|
||||
@cached_property
|
||||
def with_streaming_response(self) -> WeightsResourceWithStreamingResponse:
|
||||
"""
|
||||
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
|
||||
"""
|
||||
return WeightsResourceWithStreamingResponse(self)
|
||||
|
||||
def load(
|
||||
self,
|
||||
*,
|
||||
model_id: ModelID,
|
||||
path: str,
|
||||
type: Literal["load_weights"] | NotGiven = NOT_GIVEN,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
) -> UntypedAPIFuture:
|
||||
"""
|
||||
Loads model weights from disk
|
||||
|
||||
Args:
|
||||
path: A tinker URI for model weights at a specific step
|
||||
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
return self._post(
|
||||
"/api/v1/load_weights",
|
||||
body=maybe_transform(
|
||||
{
|
||||
"model_id": model_id,
|
||||
"path": path,
|
||||
"type": type,
|
||||
},
|
||||
weight_load_params.WeightLoadParams,
|
||||
),
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
),
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
||||
def save(
|
||||
self,
|
||||
*,
|
||||
model_id: ModelID,
|
||||
path: str | NotGiven = NOT_GIVEN,
|
||||
type: Literal["save_weights"] | NotGiven = NOT_GIVEN,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
) -> UntypedAPIFuture:
|
||||
"""
|
||||
Saves the current model weights to disk
|
||||
|
||||
Args:
|
||||
path: A file/directory name for the weights
|
||||
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
return self._post(
|
||||
"/api/v1/save_weights",
|
||||
body=maybe_transform(
|
||||
{
|
||||
"model_id": model_id,
|
||||
"path": path,
|
||||
"type": type,
|
||||
},
|
||||
weight_save_params.WeightSaveParams,
|
||||
),
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
),
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
||||
def save_for_sampler(
|
||||
self,
|
||||
*,
|
||||
model_id: ModelID,
|
||||
path: str | NotGiven = NOT_GIVEN,
|
||||
type: Literal["save_weights_for_sampler"] | NotGiven = NOT_GIVEN,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
) -> UntypedAPIFuture:
|
||||
"""
|
||||
Saves weights in a format compatible with sampling/inference servers
|
||||
|
||||
Args:
|
||||
path: A file/directory name for the weights
|
||||
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
return self._post(
|
||||
"/api/v1/save_weights_for_sampler",
|
||||
body=maybe_transform(
|
||||
{
|
||||
"model_id": model_id,
|
||||
"path": path,
|
||||
"type": type,
|
||||
},
|
||||
weight_save_for_sampler_params.WeightSaveForSamplerParams,
|
||||
),
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
),
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
||||
def list(
|
||||
self,
|
||||
model_id: ModelID,
|
||||
*,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> CheckpointsListResponse:
|
||||
"""
|
||||
Lists available model checkpoints (both training and sampler)
|
||||
|
||||
Args:
|
||||
model_id: The model ID to list checkpoints for
|
||||
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
"""
|
||||
if not model_id:
|
||||
raise ValueError(f"Expected a non-empty value for `model_id` but received {model_id!r}")
|
||||
return self._get(
|
||||
f"/api/v1/models/{model_id}/checkpoints",
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
),
|
||||
cast_to=CheckpointsListResponse,
|
||||
)
|
||||
|
||||
def delete_checkpoint(
|
||||
self,
|
||||
*,
|
||||
model_id: ModelID,
|
||||
checkpoint_id: str,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> None:
|
||||
"""Delete a checkpoint for the given training run."""
|
||||
if not model_id:
|
||||
raise ValueError(f"Expected a non-empty value for `model_id` but received {model_id!r}")
|
||||
if not checkpoint_id:
|
||||
raise ValueError(
|
||||
f"Expected a non-empty value for `checkpoint_id` but received {checkpoint_id!r}"
|
||||
)
|
||||
|
||||
self._delete(
|
||||
f"/api/v1/training_runs/{model_id}/checkpoints/{checkpoint_id}",
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
),
|
||||
cast_to=NoneType,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
class AsyncWeightsResource(AsyncAPIResource):
|
||||
@cached_property
|
||||
def with_raw_response(self) -> AsyncWeightsResourceWithRawResponse:
|
||||
"""
|
||||
This property can be used as a prefix for any HTTP method call to return
|
||||
the raw response object instead of the parsed content.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#accessing-raw-response-data-eg-headers
|
||||
"""
|
||||
return AsyncWeightsResourceWithRawResponse(self)
|
||||
|
||||
@cached_property
|
||||
def with_streaming_response(self) -> AsyncWeightsResourceWithStreamingResponse:
|
||||
"""
|
||||
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
|
||||
|
||||
For more information, see https://www.github.com/stainless-sdks/tinker-python#with_streaming_response
|
||||
"""
|
||||
return AsyncWeightsResourceWithStreamingResponse(self)
|
||||
|
||||
async def load(
|
||||
self,
|
||||
*,
|
||||
model_id: ModelID,
|
||||
path: str,
|
||||
type: Literal["load_weights"] | NotGiven = NOT_GIVEN,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
) -> UntypedAPIFuture:
|
||||
"""
|
||||
Loads model weights from disk
|
||||
|
||||
Args:
|
||||
path: A tinker URI for model weights at a specific step
|
||||
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
return await self._post(
|
||||
"/api/v1/load_weights",
|
||||
body=await async_maybe_transform(
|
||||
{
|
||||
"model_id": model_id,
|
||||
"path": path,
|
||||
"type": type,
|
||||
},
|
||||
weight_load_params.WeightLoadParams,
|
||||
),
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
),
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
||||
async def save(
|
||||
self,
|
||||
*,
|
||||
model_id: ModelID,
|
||||
path: str | NotGiven = NOT_GIVEN,
|
||||
type: Literal["save_weights"] | NotGiven = NOT_GIVEN,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
) -> UntypedAPIFuture:
|
||||
"""
|
||||
Saves the current model weights to disk
|
||||
|
||||
Args:
|
||||
path: A file/directory name for the weights
|
||||
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
return await self._post(
|
||||
"/api/v1/save_weights",
|
||||
body=await async_maybe_transform(
|
||||
{
|
||||
"model_id": model_id,
|
||||
"path": path,
|
||||
"type": type,
|
||||
},
|
||||
weight_save_params.WeightSaveParams,
|
||||
),
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
),
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
||||
async def save_for_sampler(
|
||||
self,
|
||||
*,
|
||||
model_id: ModelID,
|
||||
path: str | NotGiven = NOT_GIVEN,
|
||||
type: Literal["save_weights_for_sampler"] | NotGiven = NOT_GIVEN,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
idempotency_key: str | None = None,
|
||||
) -> UntypedAPIFuture:
|
||||
"""
|
||||
Saves weights in a format compatible with sampling/inference servers
|
||||
|
||||
Args:
|
||||
path: A file/directory name for the weights
|
||||
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
|
||||
idempotency_key: Specify a custom idempotency key for this request
|
||||
"""
|
||||
return await self._post(
|
||||
"/api/v1/save_weights_for_sampler",
|
||||
body=await async_maybe_transform(
|
||||
{
|
||||
"model_id": model_id,
|
||||
"path": path,
|
||||
"type": type,
|
||||
},
|
||||
weight_save_for_sampler_params.WeightSaveForSamplerParams,
|
||||
),
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
idempotency_key=idempotency_key,
|
||||
),
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
||||
async def list(
|
||||
self,
|
||||
model_id: ModelID,
|
||||
*,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> CheckpointsListResponse:
|
||||
"""
|
||||
Lists available model checkpoints (both training and sampler)
|
||||
|
||||
Args:
|
||||
model_id: The model ID to list checkpoints for
|
||||
|
||||
extra_headers: Send extra headers
|
||||
|
||||
extra_query: Add additional query parameters to the request
|
||||
|
||||
extra_body: Add additional JSON properties to the request
|
||||
|
||||
timeout: Override the client-level default timeout for this request, in seconds
|
||||
"""
|
||||
if not model_id:
|
||||
raise ValueError(f"Expected a non-empty value for `model_id` but received {model_id!r}")
|
||||
return await self._get(
|
||||
f"/api/v1/training_runs/{model_id}/checkpoints",
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
),
|
||||
cast_to=CheckpointsListResponse,
|
||||
)
|
||||
|
||||
async def delete_checkpoint(
|
||||
self,
|
||||
*,
|
||||
model_id: ModelID,
|
||||
checkpoint_id: str,
|
||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> None:
|
||||
"""Delete a checkpoint for the given training run."""
|
||||
if not model_id:
|
||||
raise ValueError(f"Expected a non-empty value for `model_id` but received {model_id!r}")
|
||||
if not checkpoint_id:
|
||||
raise ValueError(
|
||||
f"Expected a non-empty value for `checkpoint_id` but received {checkpoint_id!r}"
|
||||
)
|
||||
|
||||
await self._delete(
|
||||
f"/api/v1/training_runs/{model_id}/checkpoints/{checkpoint_id}",
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
),
|
||||
cast_to=NoneType,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class WeightsResourceWithRawResponse:
|
||||
def __init__(self, weights: WeightsResource) -> None:
|
||||
self._weights = weights
|
||||
|
||||
self.load = to_raw_response_wrapper(
|
||||
weights.load,
|
||||
)
|
||||
self.save = to_raw_response_wrapper(
|
||||
weights.save,
|
||||
)
|
||||
self.save_for_sampler = to_raw_response_wrapper(
|
||||
weights.save_for_sampler,
|
||||
)
|
||||
self.list = to_raw_response_wrapper(
|
||||
weights.list,
|
||||
)
|
||||
self.delete_checkpoint = to_raw_response_wrapper(
|
||||
weights.delete_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
class AsyncWeightsResourceWithRawResponse:
|
||||
def __init__(self, weights: AsyncWeightsResource) -> None:
|
||||
self._weights = weights
|
||||
|
||||
self.load = async_to_raw_response_wrapper(
|
||||
weights.load,
|
||||
)
|
||||
self.save = async_to_raw_response_wrapper(
|
||||
weights.save,
|
||||
)
|
||||
self.save_for_sampler = async_to_raw_response_wrapper(
|
||||
weights.save_for_sampler,
|
||||
)
|
||||
self.list = async_to_raw_response_wrapper(
|
||||
weights.list,
|
||||
)
|
||||
self.delete_checkpoint = async_to_raw_response_wrapper(
|
||||
weights.delete_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
class WeightsResourceWithStreamingResponse:
|
||||
def __init__(self, weights: WeightsResource) -> None:
|
||||
self._weights = weights
|
||||
|
||||
self.load = to_streamed_response_wrapper(
|
||||
weights.load,
|
||||
)
|
||||
self.save = to_streamed_response_wrapper(
|
||||
weights.save,
|
||||
)
|
||||
self.save_for_sampler = to_streamed_response_wrapper(
|
||||
weights.save_for_sampler,
|
||||
)
|
||||
self.list = to_streamed_response_wrapper(
|
||||
weights.list,
|
||||
)
|
||||
self.delete_checkpoint = to_streamed_response_wrapper(
|
||||
weights.delete_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
class AsyncWeightsResourceWithStreamingResponse:
|
||||
def __init__(self, weights: AsyncWeightsResource) -> None:
|
||||
self._weights = weights
|
||||
|
||||
self.load = async_to_streamed_response_wrapper(
|
||||
weights.load,
|
||||
)
|
||||
self.save = async_to_streamed_response_wrapper(
|
||||
weights.save,
|
||||
)
|
||||
self.save_for_sampler = async_to_streamed_response_wrapper(
|
||||
weights.save_for_sampler,
|
||||
)
|
||||
self.list = async_to_streamed_response_wrapper(
|
||||
weights.list,
|
||||
)
|
||||
self.delete_checkpoint = async_to_streamed_response_wrapper(
|
||||
weights.delete_checkpoint,
|
||||
)
|
||||
90
src/tinker/types/__init__.py
Normal file
90
src/tinker/types/__init__.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# There's an underscore in front of *Param classes (TypedDict) because they shouldn't be used.
|
||||
from .datum import Datum as Datum
|
||||
from .shared import UntypedAPIFuture as UntypedAPIFuture
|
||||
from .model_id import ModelID as ModelID
|
||||
from .severity import Severity as Severity
|
||||
from .event_type import EventType as EventType
|
||||
from .request_id import RequestID as RequestID
|
||||
from .datum_param import DatumParam as _DatumParam
|
||||
from .lora_config import LoraConfig as LoraConfig
|
||||
from .model_input import ModelInput as ModelInput
|
||||
from .stop_reason import StopReason as StopReason
|
||||
from .tensor_data import TensorData as TensorData
|
||||
from .loss_fn_type import LossFnType as LossFnType
|
||||
from .tensor_dtype import TensorDtype as TensorDtype
|
||||
from .error_response import ErrorResponse as ErrorResponse
|
||||
from .loss_fn_inputs import LossFnInputs as LossFnInputs
|
||||
from .loss_fn_output import LossFnOutput as LossFnOutput
|
||||
from .sample_request import SampleRequest as SampleRequest
|
||||
from .health_response import HealthResponse as HealthResponse
|
||||
from .sample_response import SampleResponse as SampleResponse
|
||||
from .sampling_params import SamplingParams as SamplingParams
|
||||
from .telemetry_batch import TelemetryBatch as TelemetryBatch
|
||||
from .telemetry_event import TelemetryEvent as TelemetryEvent
|
||||
from .get_info_request import GetInfoRequest as GetInfoRequest
|
||||
from .sampled_sequence import SampledSequence as SampledSequence
|
||||
from .get_info_response import GetInfoResponse as GetInfoResponse
|
||||
from .get_info_response import ModelData as ModelData
|
||||
from .lora_config_param import LoraConfigParam as _LoraConfigParam
|
||||
from .model_input_chunk import ModelInputChunk as ModelInputChunk
|
||||
from .model_input_param import ModelInputParam as _ModelInputParam
|
||||
from .tensor_data_param import TensorDataParam as _TensorDataParam
|
||||
from .encoded_text_chunk import EncodedTextChunk as EncodedTextChunk
|
||||
from .optim_step_request import OptimStepRequest as OptimStepRequest
|
||||
from .checkpoint import Checkpoint as Checkpoint, CheckpointType as CheckpointType, ParsedCheckpointTinkerPath as ParsedCheckpointTinkerPath
|
||||
from .weight_load_params import WeightLoadParams as _WeightLoadParams
|
||||
from .weight_save_params import WeightSaveParams as _WeightSaveParams
|
||||
from .checkpoints_list_response import CheckpointsListResponse as CheckpointsListResponse
|
||||
from .cursor import Cursor as Cursor
|
||||
from .training_run_ids_response import TrainingRunIdsResponse as TrainingRunIdsResponse
|
||||
from .training_runs_response import TrainingRunsResponse as TrainingRunsResponse
|
||||
from .forward_backward_input_param import ForwardBackwardInputParam as _ForwardBackwardInputParam
|
||||
from .forward_backward_input import ForwardBackwardInput as ForwardBackwardInput
|
||||
from .forward_backward_output import ForwardBackwardOutput as ForwardBackwardOutput
|
||||
from .model_create_params import ModelCreateParams as _ModelCreateParams
|
||||
from .model_unload_params import ModelUnloadParams as _ModelUnloadParams
|
||||
from .session_end_event import SessionEndEvent as SessionEndEvent
|
||||
from .telemetry_response import TelemetryResponse as TelemetryResponse
|
||||
from .try_again_response import TryAgainResponse as TryAgainResponse
|
||||
from .optim_step_response import OptimStepResponse as OptimStepResponse
|
||||
from .session_start_event import SessionStartEvent as SessionStartEvent
|
||||
from .create_model_request import CreateModelRequest as CreateModelRequest
|
||||
from .load_weights_request import LoadWeightsRequest as LoadWeightsRequest
|
||||
from .loss_fn_inputs_param import LossFnInputsParam as _LossFnInputsParam
|
||||
from .save_weights_request import SaveWeightsRequest as SaveWeightsRequest
|
||||
from .unload_model_request import UnloadModelRequest as UnloadModelRequest
|
||||
from .create_model_response import CreateModelResponse as CreateModelResponse
|
||||
from .load_weights_response import LoadWeightsResponse as LoadWeightsResponse
|
||||
from .model_get_info_params import ModelGetInfoParams as _ModelGetInfoParams
|
||||
from .save_weights_response import SaveWeightsResponse as SaveWeightsResponse
|
||||
from .telemetry_event_param import TelemetryEventParam as TelemetryEventParam
|
||||
from .telemetry_send_params import TelemetrySendParams as TelemetrySendParams
|
||||
from .unload_model_response import UnloadModelResponse as UnloadModelResponse
|
||||
from .future_retrieve_params import FutureRetrieveParams as _FutureRetrieveParams
|
||||
from .model_input_chunk_param import ModelInputChunkParam as _ModelInputChunkParam
|
||||
from .training_forward_params import TrainingForwardParams as _TrainingForwardParams
|
||||
from .encoded_text_chunk_param import EncodedTextChunkParam as _EncodedTextChunkParam
|
||||
from .sampling_params_param import SamplingParamsParam as _SamplingParamsParam
|
||||
from .sampling_sample_params import SamplingSampleParams as _SamplingSampleParams
|
||||
from .sampling_asample_params import SamplingAsampleParams as _SamplingAsampleParams
|
||||
from .future_retrieve_response import FutureRetrieveResponse as FutureRetrieveResponse
|
||||
from .compute_logprobs_response import ComputeLogprobsResponse as ComputeLogprobsResponse
|
||||
from .image_asset_pointer_chunk import ImageAssetPointerChunk as ImageAssetPointerChunk
|
||||
from .training_optim_step_params import TrainingOptimStepParams as _TrainingOptimStepParams
|
||||
from .weight_save_for_sampler_params import WeightSaveForSamplerParams as _WeightSaveForSamplerParams
|
||||
from .image_asset_pointer_chunk_param import ImageAssetPointerChunkParam as _ImageAssetPointerChunkParam
|
||||
from .session_end_event_param import SessionEndEventParam as _SessionEndEventParam
|
||||
from .session_start_event_param import SessionStartEventParam as _SessionStartEventParam
|
||||
from .unhandled_exception_event import UnhandledExceptionEvent as UnhandledExceptionEvent
|
||||
from .unhandled_exception_event_param import UnhandledExceptionEventParam as _UnhandledExceptionEventParam
|
||||
from .get_server_capabilities_response import GetServerCapabilitiesResponse as GetServerCapabilitiesResponse
|
||||
from .save_weights_for_sampler_request import SaveWeightsForSamplerRequest as SaveWeightsForSamplerRequest
|
||||
from .get_server_capabilities_response import SupportedModel as SupportedModel
|
||||
from .training_forward_backward_params import TrainingForwardBackwardParams as _TrainingForwardBackwardParams
|
||||
from .save_weights_for_sampler_response import SaveWeightsForSamplerResponse as SaveWeightsForSamplerResponse
|
||||
from .optim_step_request import AdamParams as AdamParams
|
||||
from .training_run import TrainingRun as TrainingRun
|
||||
54
src/tinker/types/checkpoint.py
Normal file
54
src/tinker/types/checkpoint.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from .._models import BaseModel
|
||||
|
||||
__all__ = ["Checkpoint", "CheckpointType"]
|
||||
|
||||
CheckpointType = Literal["training", "sampler"]
|
||||
|
||||
|
||||
class Checkpoint(BaseModel):
|
||||
checkpoint_id: str
|
||||
"""The checkpoint ID"""
|
||||
|
||||
checkpoint_type: CheckpointType
|
||||
"""The type of checkpoint (training or sampler)"""
|
||||
|
||||
time: datetime
|
||||
"""The time when the checkpoint was created"""
|
||||
|
||||
tinker_path: str
|
||||
"""The tinker path to the checkpoint"""
|
||||
|
||||
|
||||
class ParsedCheckpointTinkerPath(BaseModel):
|
||||
tinker_path: str
|
||||
"""The tinker path to the checkpoint"""
|
||||
|
||||
training_run_id: str
|
||||
"""The training run ID"""
|
||||
|
||||
checkpoint_type: CheckpointType
|
||||
"""The type of checkpoint (training or sampler)"""
|
||||
|
||||
checkpoint_id: str
|
||||
"""The checkpoint ID"""
|
||||
|
||||
@classmethod
|
||||
def from_tinker_path(cls, tinker_path: str) -> "ParsedCheckpointTinkerPath":
|
||||
"""Parse a tinker path to an instance of ParsedCheckpointTinkerPath"""
|
||||
if not tinker_path.startswith("tinker://"):
|
||||
raise ValueError(f"Invalid tinker path: {tinker_path}")
|
||||
parts = tinker_path[9:].split("/")
|
||||
if len(parts) != 3:
|
||||
raise ValueError(f"Invalid tinker path: {tinker_path}")
|
||||
if parts[1] not in ["weights", "sampler_weights"]:
|
||||
raise ValueError(f"Invalid tinker path: {tinker_path}")
|
||||
checkpoint_type = "training" if parts[1] == "weights" else "sampler"
|
||||
return cls(
|
||||
tinker_path=tinker_path,
|
||||
training_run_id=parts[0],
|
||||
checkpoint_type=checkpoint_type,
|
||||
checkpoint_id="/".join(parts[1:]),
|
||||
)
|
||||
9
src/tinker/types/checkpoints_list_response.py
Normal file
9
src/tinker/types/checkpoints_list_response.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
from .._models import BaseModel
|
||||
from .checkpoint import Checkpoint
|
||||
|
||||
__all__ = ["CheckpointsListResponse"]
|
||||
|
||||
|
||||
class CheckpointsListResponse(BaseModel):
|
||||
checkpoints: list[Checkpoint]
|
||||
"""List of available model checkpoints for the model"""
|
||||
14
src/tinker/types/compute_logprobs_response.py
Normal file
14
src/tinker/types/compute_logprobs_response.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from typing import Optional
|
||||
from typing_extensions import Literal, Sequence
|
||||
|
||||
from .._models import BaseModel
|
||||
|
||||
__all__ = ["ComputeLogprobsResponse"]
|
||||
|
||||
|
||||
class ComputeLogprobsResponse(BaseModel):
|
||||
logprobs: Sequence[Optional[float]]
|
||||
|
||||
type: Literal["compute_logprobs"] = "compute_logprobs"
|
||||
17
src/tinker/types/create_model_request.py
Normal file
17
src/tinker/types/create_model_request.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from typing import Optional
|
||||
from typing_extensions import Literal
|
||||
|
||||
from .._models import StrictBase
|
||||
from .lora_config import LoraConfig
|
||||
|
||||
__all__ = ["CreateModelRequest"]
|
||||
|
||||
|
||||
class CreateModelRequest(StrictBase):
|
||||
base_model: str
|
||||
|
||||
lora_config: Optional[LoraConfig] = None
|
||||
|
||||
type: Literal["create_model"] = "create_model"
|
||||
15
src/tinker/types/create_model_response.py
Normal file
15
src/tinker/types/create_model_response.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
from .._compat import PYDANTIC_V2, ConfigDict
|
||||
from .._models import BaseModel
|
||||
from .model_id import ModelID
|
||||
|
||||
__all__ = ["CreateModelResponse"]
|
||||
|
||||
|
||||
class CreateModelResponse(BaseModel):
|
||||
model_id: ModelID
|
||||
|
||||
type: Literal["create_model"] = "create_model"
|
||||
14
src/tinker/types/cursor.py
Normal file
14
src/tinker/types/cursor.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from .._models import BaseModel
|
||||
|
||||
__all__ = ["Cursor"]
|
||||
|
||||
|
||||
class Cursor(BaseModel):
|
||||
offset: int
|
||||
"""The offset used for pagination"""
|
||||
|
||||
limit: int
|
||||
"""The maximum number of items requested"""
|
||||
|
||||
total_count: int
|
||||
"""The total number of items available"""
|
||||
66
src/tinker/types/datum.py
Normal file
66
src/tinker/types/datum.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from pydantic import model_validator
|
||||
|
||||
from .._models import StrictBase
|
||||
from .loss_fn_inputs import LossFnInputs
|
||||
from .model_input import ModelInput
|
||||
from .tensor_data import TensorData
|
||||
|
||||
try:
|
||||
import torch # type: ignore[import-not-found]
|
||||
|
||||
_HAVE_TORCH = True
|
||||
except ImportError:
|
||||
_HAVE_TORCH = False
|
||||
|
||||
import numpy as np
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
__all__ = ["Datum"]
|
||||
|
||||
|
||||
class Datum(StrictBase):
|
||||
loss_fn_inputs: LossFnInputs
|
||||
"""Dictionary mapping field names to tensor data"""
|
||||
|
||||
model_input: ModelInput
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def convert_tensors(cls, data: Any) -> Any:
|
||||
"""Convert torch.Tensor and numpy arrays to TensorData in loss_fn_inputs during construction."""
|
||||
if isinstance(data, dict) and "loss_fn_inputs" in data:
|
||||
loss_fn_inputs = data["loss_fn_inputs"]
|
||||
if isinstance(loss_fn_inputs, dict):
|
||||
converted_inputs = {}
|
||||
for key, value in loss_fn_inputs.items():
|
||||
converted_inputs[key] = cls._maybe_convert_array(key, value)
|
||||
data = dict(data) # Make a copy
|
||||
data["loss_fn_inputs"] = converted_inputs
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def _maybe_convert_array(cls, key: str, value: Any) -> Any:
|
||||
"""Convert torch.Tensor, numpy array, or 1-D list to TensorData if needed."""
|
||||
if _HAVE_TORCH and isinstance(value, torch.Tensor):
|
||||
return TensorData.from_torch(value)
|
||||
elif isinstance(value, np.ndarray):
|
||||
return TensorData.from_numpy(value)
|
||||
elif isinstance(value, list):
|
||||
# assume it's 1d and infer the dtype from the key
|
||||
return TensorData(data=value, dtype=_key_to_type[key], shape=[len(value)])
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
_key_to_type = {
|
||||
"target_tokens": "int64",
|
||||
"weights": "float32",
|
||||
"advantages": "float32",
|
||||
"logprobs": "float32",
|
||||
}
|
||||
17
src/tinker/types/datum_param.py
Normal file
17
src/tinker/types/datum_param.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
from .model_input_param import ModelInputParam
|
||||
from .loss_fn_inputs_param import LossFnInputsParam
|
||||
|
||||
__all__ = ["DatumParam"]
|
||||
|
||||
|
||||
class DatumParam(TypedDict, total=False):
|
||||
loss_fn_inputs: Required[LossFnInputsParam]
|
||||
"""Dictionary mapping field names to tensor data"""
|
||||
|
||||
model_input: Required[ModelInputParam]
|
||||
20
src/tinker/types/encoded_text_chunk.py
Normal file
20
src/tinker/types/encoded_text_chunk.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from typing import Sequence
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
from .._models import StrictBase
|
||||
|
||||
__all__ = ["EncodedTextChunk"]
|
||||
|
||||
|
||||
class EncodedTextChunk(StrictBase):
|
||||
tokens: Sequence[int]
|
||||
"""Array of token IDs"""
|
||||
|
||||
type: Literal["encoded_text"] = "encoded_text"
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
return len(self.tokens)
|
||||
15
src/tinker/types/encoded_text_chunk_param.py
Normal file
15
src/tinker/types/encoded_text_chunk_param.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable
|
||||
from typing_extensions import Literal, Required, TypedDict
|
||||
|
||||
__all__ = ["EncodedTextChunkParam"]
|
||||
|
||||
|
||||
class EncodedTextChunkParam(TypedDict, total=False):
|
||||
tokens: Required[Iterable[int]]
|
||||
"""Array of token IDs"""
|
||||
|
||||
type: Required[Literal["encoded_text"]]
|
||||
18
src/tinker/types/error_response.py
Normal file
18
src/tinker/types/error_response.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
from .._models import BaseModel
|
||||
|
||||
__all__ = ["ErrorResponse"]
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
error: str
|
||||
"""Error code"""
|
||||
|
||||
message: str
|
||||
"""Human-readable error message"""
|
||||
|
||||
details: Optional[Dict[str, object]] = None
|
||||
"""Additional error details"""
|
||||
9
src/tinker/types/event_type.py
Normal file
9
src/tinker/types/event_type.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
|
||||
__all__ = ["EventType"]
|
||||
|
||||
EventType: TypeAlias = Literal[
|
||||
"SESSION_START", "SESSION_END", "UNHANDLED_EXCEPTION", "GENERIC_EVENT"
|
||||
]
|
||||
17
src/tinker/types/forward_backward_input.py
Normal file
17
src/tinker/types/forward_backward_input.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from typing import List
|
||||
|
||||
from .datum import Datum
|
||||
from .._models import StrictBase
|
||||
from .loss_fn_type import LossFnType
|
||||
|
||||
__all__ = ["ForwardBackwardInput"]
|
||||
|
||||
|
||||
class ForwardBackwardInput(StrictBase):
|
||||
data: List[Datum]
|
||||
"""Array of input data for the forward/backward pass"""
|
||||
|
||||
loss_fn: LossFnType
|
||||
"""Fully qualified function path for the loss function"""
|
||||
19
src/tinker/types/forward_backward_input_param.py
Normal file
19
src/tinker/types/forward_backward_input_param.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
from .datum_param import DatumParam
|
||||
from .loss_fn_type import LossFnType
|
||||
|
||||
__all__ = ["ForwardBackwardInputParam"]
|
||||
|
||||
|
||||
class ForwardBackwardInputParam(TypedDict, total=False):
|
||||
data: Required[Iterable[DatumParam]]
|
||||
"""Array of input data for the forward/backward pass"""
|
||||
|
||||
loss_fn: Required[LossFnType]
|
||||
"""Fully qualified function path for the loss function"""
|
||||
19
src/tinker/types/forward_backward_output.py
Normal file
19
src/tinker/types/forward_backward_output.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
from .._models import BaseModel
|
||||
from .loss_fn_output import LossFnOutput
|
||||
|
||||
__all__ = ["ForwardBackwardOutput"]
|
||||
|
||||
|
||||
class ForwardBackwardOutput(BaseModel):
|
||||
loss_fn_output_type: str
|
||||
"""The type of the ForwardBackward output. Can be one of [...] TODO"""
|
||||
|
||||
loss_fn_outputs: List[LossFnOutput]
|
||||
"""Dictionary mapping field names to tensor data"""
|
||||
|
||||
metrics: Dict[str, float]
|
||||
"""Training metrics as key-value pairs"""
|
||||
16
src/tinker/types/future_retrieve_params.py
Normal file
16
src/tinker/types/future_retrieve_params.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
from .model_id import ModelID
|
||||
from .request_id import RequestID
|
||||
|
||||
__all__ = ["FutureRetrieveParams"]
|
||||
|
||||
|
||||
class FutureRetrieveParams(TypedDict, total=False):
|
||||
request_id: Required[RequestID]
|
||||
|
||||
model_id: ModelID
|
||||
26
src/tinker/types/future_retrieve_response.py
Normal file
26
src/tinker/types/future_retrieve_response.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from typing import Union
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from .try_again_response import TryAgainResponse
|
||||
from .optim_step_response import OptimStepResponse
|
||||
from .create_model_response import CreateModelResponse
|
||||
from .load_weights_response import LoadWeightsResponse
|
||||
from .save_weights_response import SaveWeightsResponse
|
||||
from .unload_model_response import UnloadModelResponse
|
||||
from .forward_backward_output import ForwardBackwardOutput
|
||||
from .save_weights_for_sampler_response import SaveWeightsForSamplerResponse
|
||||
|
||||
__all__ = ["FutureRetrieveResponse"]
|
||||
|
||||
FutureRetrieveResponse: TypeAlias = Union[
|
||||
TryAgainResponse,
|
||||
ForwardBackwardOutput,
|
||||
OptimStepResponse,
|
||||
SaveWeightsResponse,
|
||||
LoadWeightsResponse,
|
||||
SaveWeightsForSamplerResponse,
|
||||
CreateModelResponse,
|
||||
UnloadModelResponse,
|
||||
]
|
||||
30
src/tinker/types/generic_event.py
Normal file
30
src/tinker/types/generic_event.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Dict
|
||||
|
||||
from .._models import BaseModel
|
||||
from .event_type import EventType
|
||||
from .severity import Severity
|
||||
|
||||
__all__ = ["GenericEvent"]
|
||||
|
||||
|
||||
class GenericEvent(BaseModel):
|
||||
event: EventType
|
||||
"""Telemetry event type"""
|
||||
|
||||
event_id: str
|
||||
|
||||
event_name: str
|
||||
"""Low-cardinality event name"""
|
||||
|
||||
event_session_index: int
|
||||
|
||||
severity: Severity
|
||||
"""Log severity level"""
|
||||
|
||||
timestamp: datetime
|
||||
|
||||
event_data: Dict[str, object] = {}
|
||||
"""Arbitrary structured JSON payload"""
|
||||
34
src/tinker/types/generic_event_param.py
Normal file
34
src/tinker/types/generic_event_param.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Dict, Union
|
||||
|
||||
from typing_extensions import Annotated, Required, TypedDict
|
||||
|
||||
from .._utils import PropertyInfo
|
||||
from .event_type import EventType
|
||||
from .severity import Severity
|
||||
|
||||
__all__ = ["GenericEventParam"]
|
||||
|
||||
|
||||
class GenericEventParam(TypedDict, total=False):
|
||||
event: Required[EventType]
|
||||
"""Telemetry event type"""
|
||||
|
||||
event_id: Required[str]
|
||||
|
||||
event_name: Required[str]
|
||||
"""Low-cardinality event name"""
|
||||
|
||||
event_session_index: Required[int]
|
||||
|
||||
severity: Required[Severity]
|
||||
"""Log severity level"""
|
||||
|
||||
timestamp: Required[Annotated[Union[str, datetime], PropertyInfo(format="iso8601")]]
|
||||
|
||||
event_data: Dict[str, object]
|
||||
"""Arbitrary structured JSON payload"""
|
||||
16
src/tinker/types/get_info_request.py
Normal file
16
src/tinker/types/get_info_request.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from typing import Optional
|
||||
from typing_extensions import Literal
|
||||
|
||||
from .._compat import PYDANTIC_V2, ConfigDict
|
||||
from .._models import StrictBase
|
||||
from .model_id import ModelID
|
||||
|
||||
__all__ = ["GetInfoRequest"]
|
||||
|
||||
|
||||
class GetInfoRequest(StrictBase):
|
||||
model_id: ModelID
|
||||
|
||||
type: Optional[Literal["get_info"]] = None
|
||||
32
src/tinker/types/get_info_response.py
Normal file
32
src/tinker/types/get_info_response.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from typing import Optional, Literal
|
||||
|
||||
from .._compat import PYDANTIC_V2, ConfigDict
|
||||
from .._models import BaseModel
|
||||
from .model_id import ModelID
|
||||
|
||||
__all__ = ["GetInfoResponse", "ModelData"]
|
||||
|
||||
|
||||
class ModelData(BaseModel):
|
||||
arch: Optional[str] = None
|
||||
|
||||
model_name: Optional[str] = None
|
||||
|
||||
class GetInfoResponse(BaseModel):
|
||||
type: Optional[Literal["get_info"]] = None
|
||||
|
||||
model_data: ModelData
|
||||
|
||||
model_id: ModelID
|
||||
|
||||
is_lora: Optional[bool] = None
|
||||
|
||||
lora_rank: Optional[int] = None
|
||||
|
||||
model_name: Optional[str] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
# allow fields with a `model_` prefix
|
||||
model_config = ConfigDict(protected_namespaces=tuple())
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue