How to Obtain Runtime Data
To obtain a data tensor or its gradient inside a job function, follow these steps:
Write a callback function whose parameters are annotated to indicate their data type. The logic of the callback function is up to the user.
When defining the job function, use oneflow.watch or oneflow.watch_diff to register the callback function. The former delivers the data tensor itself, and the latter delivers its corresponding gradient.
At the appropriate time while the job function is running, OneFlow calls the registered callback function, passes the monitored data to it, and executes the logic in the callback.
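The register-then-call flow described in these steps can be sketched in plain Python. This is only an illustration of the callback pattern, not OneFlow's implementation: MiniJob, its watch and run methods, and my_handler are hypothetical names invented for this sketch, and the doubling step merely stands in for some intermediate computation.

```python
from typing import Callable, Dict, List

Handler = Callable[[List[float]], None]


class MiniJob:
    """Hypothetical stand-in for a job function; not part of OneFlow."""

    def __init__(self) -> None:
        # Maps the name of a watched value to its registered callbacks.
        self._watchers: Dict[str, List[Handler]] = {}

    def watch(self, name: str, handler: Handler) -> None:
        """Register the handler; nothing is called yet."""
        self._watchers.setdefault(name, []).append(handler)

    def run(self, x: List[float]) -> None:
        """Compute an intermediate value and hand it to its watchers."""
        y = [v * 2.0 for v in x]  # the "intermediate" value being watched
        for handler in self._watchers.get("y", []):
            handler(y)  # the callback fires at run time


seen: List[List[float]] = []


def my_handler(y: List[float]) -> None:
    # User-defined logic: here we simply record what we were given.
    seen.append(y)


job = MiniJob()
job.watch("y", my_handler)  # definition time: register only
job.run([1.0, -2.0, 3.0])   # run time: my_handler receives y
print(seen)
```

The point of the sketch is the timing: registration happens once when the job is defined, while the callback body runs later, each time the job executes.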
Take oneflow.watch as an example: the callback's parameter is annotated with a type T, where T is a data type from oneflow.typing, such as oneflow.typing.Numpy. For details, please refer to this article.
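To make the role of the annotation concrete, the sketch below shows how a framework could read a callback's parameter annotation with the standard typing.get_type_hints function. It rests on assumptions: FakeNumpy is a hypothetical placeholder for a type such as oneflow.typing.Numpy, declared_type is a helper invented here, and OneFlow's actual mechanism may differ.

```python
from typing import get_type_hints


class FakeNumpy:
    """Hypothetical placeholder for a type like oneflow.typing.Numpy."""


def my_watch(x: FakeNumpy) -> None:
    # User-defined processing of x would go here.
    pass


def unannotated(x):
    # A callback without an annotation gives the framework nothing to read.
    pass


def declared_type(handler, param: str = "x"):
    """Return the annotated type of the named parameter, or None if missing."""
    return get_type_hints(handler).get(param)


print(declared_type(my_watch))     # the FakeNumpy class
print(declared_type(unannotated))  # None
```

A callback without an annotation leaves the data type undetermined, which is why the annotation is required in the first step.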
The following example demonstrates how to use oneflow.watch to obtain the output of an intermediate layer in OneFlow. Code: test_watch.py
Run the code above:
python3 test_watch.py
The output looks like the following:
In this example, we focus on y in ReluJob, so we call flow.watch(y, watch_handler) to monitor it. The function oneflow.watch takes two parameters:
The first parameter is y, the tensor we want to monitor.
The second parameter is a callback function. When OneFlow uses device resources to execute ReluJob, it passes y as an argument to the callback function. Here the callback function simply prints its argument.
The following example demonstrates how to use oneflow.watch_diff to obtain the gradient at runtime.
Code: test_watch_diff.py
Run the code above:
python3 test_watch_diff.py
The output should look like the following:
Code Explanation
In the example above, we use oneflow.watch_diff to obtain the gradient. The process is the same as in the earlier example, which uses oneflow.watch to obtain the data tensor.
First, we define the callback function:
def watch_diff_handler(blob: tp.Numpy):
    print("watch_diff_handler:", blob, blob.shape, blob.dtype)
When the job runs, the OneFlow framework calls watch_diff_handler and passes it the gradient of the watched tensor.
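The difference between watching a value and watching its gradient can be illustrated with a tiny hand-written forward and backward pass. This is a plain-Python sketch under stated assumptions, not OneFlow code: run_relu_job, on_value, and on_grad are hypothetical names, and the loss (the sum of squared outputs) is an arbitrary choice made so that the gradient is easy to check by hand.

```python
from typing import Callable, List

Vector = List[float]


def run_relu_job(
    x: Vector,
    on_value: Callable[[Vector], None],
    on_grad: Callable[[Vector], None],
) -> Vector:
    """Hypothetical job: forward ReLU, then backward pass of L = sum(y_i ** 2)."""
    # Forward pass: y = relu(x); report y (the role played by watch).
    y = [max(v, 0.0) for v in x]
    on_value(y)

    # Backward pass: dL/dy_i = 2 * y_i; report it (the role played by watch_diff).
    dy = [2.0 * v for v in y]
    on_grad(dy)

    # Gradient through ReLU: it passes only where the input was positive.
    return [g if v > 0 else 0.0 for g, v in zip(dy, x)]


values: List[Vector] = []
grads: List[Vector] = []

dx = run_relu_job([-1.0, 0.5, 2.0], values.append, grads.append)
print("value:", values[0])
print("grad: ", grads[0])
print("dx:   ", dx)
```

Both callbacks receive a vector of the same shape as y, but the first holds the forward result and the second holds the derivative of the loss with respect to that result.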