forked from google-deepmind/deepmind-research
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmetadata_schema.py
137 lines (111 loc) · 4.65 KB
/
metadata_schema.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sqlalchemy schema for the metadata db."""
import sqlalchemy
from sqlalchemy import orm
from sqlalchemy.ext import declarative
# Short module-level aliases for the sqlalchemy names used by the schema below.
Column = sqlalchemy.Column
Integer = sqlalchemy.Integer
String = sqlalchemy.String
LargeBinary = sqlalchemy.LargeBinary
ForeignKey = sqlalchemy.ForeignKey
# The schema below uses CamelCase attribute names, so the invalid-name lint is
# switched off here and re-enabled at the end of the module.
# pylint: disable=invalid-name
# https://docs.sqlalchemy.org/en/13/orm/tutorial.html
# Declarative base class that every mapped class in this module derives from;
# its `metadata` is also passed to the `EpisodeTags` association table.
Base = declarative.declarative_base()
EpisodeTag = sqlalchemy.Table(
    'EpisodeTags',
    Base.metadata,
    Column(
        'EpisodeId',
        String,
        ForeignKey('Episodes.EpisodeId'),
        primary_key=True,
    ),
    Column(
        'Tag',
        String,
        ForeignKey('Tags.Name'),
        primary_key=True,
    ),
)
"""Association table linking episodes to the tags attached to them.

Both columns together form the composite primary key.

Attributes:
  EpisodeId: A string of digits that uniquely identifies the episode
    (foreign key into `Episodes.EpisodeId`).
  Tag: Human readable tag name (foreign key into `Tags.Name`).
"""
class Episode(Base):
  """Table describing individual episodes.

  Attributes:
    EpisodeId: A string of digits that uniquely identifies the episode.
    TaskId: A human readable name for the task corresponding to the behavior
      that generated the episode.
    DataPath: The name of the episode file holding the data for this episode.
    Timestamp: A unix timestamp recording when the episode was generated.
    EpisodeType: A string describing the type of policy that generated the
      episode. Possible values are:
        - `EPISODE_ROBOT_AGENT`: The behavior policy is a learned or scripted
          controller.
        - `EPISODE_ROBOT_TELEOPERATION`: The behavior policy is a human
          teleoperating the robot.
        - `EPISODE_ROBOT_DAGGER`: The behavior policy is a mix of controller
          and human generated actions.
    Tags: A list of tags attached to this episode.
    Rewards: A list of `RewardSequence`s containing sketched rewards for this
      episode.
  """

  __tablename__ = 'Episodes'

  EpisodeId = Column(String, primary_key=True)
  TaskId = Column(String)
  DataPath = Column(String)
  Timestamp = Column(Integer)
  EpisodeType = Column(String)

  # Many-to-many link to `Tag` through the `EpisodeTags` association table;
  # `back_populates` pairs it with `Tag.Episodes`.
  Tags = orm.relationship(
      'Tag', secondary=EpisodeTag, back_populates='Episodes')
  # Link to `RewardSequence` rows; the backref also exposes an `Episode`
  # attribute on `RewardSequence`.
  Rewards = orm.relationship('RewardSequence', backref='Episode')
class Tag(Base):
  """Table of tags that can be attached to episodes.

  Attributes:
    Name: Human readable tag name.
    Episodes: The episodes that have been annotated with this tag.
  """

  __tablename__ = 'Tags'

  Name = Column(String, primary_key=True)

  # Many-to-many link to `Episode` through the `EpisodeTags` association
  # table; `back_populates` pairs it with `Episode.Tags`.
  Episodes = orm.relationship(
      'Episode', secondary=EpisodeTag, back_populates='Tags')
class RewardSequence(Base):
  """Table of reward sequences recorded for episodes.

  A row is identified by the pair (`EpisodeId`, `RewardSequenceId`).

  Attributes:
    EpisodeId: Foreign key into the `Episodes` table.
    RewardSequenceId: Distinguishes multiple rewards for the same episode.
    RewardTaskId: A human readable name of the task for this reward signal.
      Typically the same as the corresponding `TaskId` in the `Episodes`
      table.
    Type: A string describing the type of reward signal. Currently the only
      value is `REWARD_SKETCH`.
    User: The name of the user who produced this reward sequence.
    Values: A sequence of float32 values, packed as a binary blob. There is
      one float value for each frame of the episode, corresponding to the
      annotated reward.
  """

  __tablename__ = 'RewardSequences'

  # Column names are taken from the attribute names by the declarative base.
  EpisodeId = Column(
      String, ForeignKey('Episodes.EpisodeId'), primary_key=True)
  RewardSequenceId = Column(String, primary_key=True)
  RewardTaskId = Column(String)
  Type = Column(String)
  User = Column(String)
  Values = Column(LargeBinary)
class ArchiveFile(Base):
  """Table mapping each episode to the archive file that contains it.

  Use this to work out which archive to download or extract when only a
  specific episode is needed from the distributed archives.

  Attributes:
    EpisodeId: Foreign key into the `Episodes` table.
    ArchiveFile: Name of the archive file containing the corresponding
      episode.
  """

  __tablename__ = 'ArchiveFiles'

  # Column names are taken from the attribute names by the declarative base.
  EpisodeId = Column(
      String, ForeignKey('Episodes.EpisodeId'), primary_key=True)
  ArchiveFile = Column(String)
# pylint: enable=invalid-name