@@ -152,32 +152,44 @@ void MatrixTimesVector::prepare() {
152
152
void MatrixTimesVector::setupForTask ( const unsigned & task_index, std::vector<unsigned >& indices, MultiValue& myvals ) const {
153
153
unsigned start_n = getPntrToArgument (0 )->getShape ()[0 ], size_v = getPntrToArgument (0 )->getRowLength (task_index);
154
154
if ( indices.size ()!=size_v+1 ) indices.resize ( size_v + 1 );
155
- for (unsigned i=0 ; i<size_v; ++i) indices[i+1 ] = start_n + getPntrToArgument ( 0 )-> getRowIndex ( task_index, i ) ;
155
+ for (unsigned i=0 ; i<size_v; ++i) indices[i+1 ] = start_n + i ;
156
156
myvals.setSplitIndex ( size_v + 1 );
157
157
}
158
158
159
159
void MatrixTimesVector::performTask ( const std::string& controller, const unsigned & index1, const unsigned & index2, MultiValue& myvals ) const {
160
160
unsigned ind2 = index2; if ( index2>=getPntrToArgument (0 )->getShape ()[0 ] ) ind2 = index2 - getPntrToArgument (0 )->getShape ()[0 ];
161
161
if ( getPntrToArgument (1 )->getRank ()==1 ) {
162
+ double matval = 0 ; Value* myarg = getPntrToArgument (0 ); unsigned vcol = ind2;
163
+ if ( !myarg->valueHasBeenSet () ) matval = myvals.get ( myarg->getPositionInStream () );
164
+ else {
165
+ matval = myarg->get ( index1*myarg->getNumberOfColumns () + ind2, false );
166
+ vcol = getPntrToArgument (0 )->getRowIndex ( index1, ind2 );
167
+ }
162
168
for (unsigned i=0 ; i<getNumberOfArguments ()-1 ; ++i) {
163
169
unsigned ostrn = getConstPntrToComponent (i)->getPositionInStream ();
164
- double matval = getElementOfMatrixArgument ( 0 , index1, ind2, myvals ), vecval=getArgumentElement ( i+1 , ind2 , myvals );
170
+ double vecval=getArgumentElement ( i+1 , vcol , myvals );
165
171
// And add this part of the product
166
172
myvals.addValue ( ostrn, matval*vecval );
167
173
// Now lets work out the derivatives
168
174
if ( doNotCalculateDerivatives () ) continue ;
169
- addDerivativeOnMatrixArgument ( stored_arg[0 ], i, 0 , index1, ind2, vecval, myvals ); addDerivativeOnVectorArgument ( stored_arg[i+1 ], i, i+1 , ind2 , matval, myvals );
175
+ addDerivativeOnMatrixArgument ( stored_arg[0 ], i, 0 , index1, ind2, vecval, myvals ); addDerivativeOnVectorArgument ( stored_arg[i+1 ], i, i+1 , vcol , matval, myvals );
170
176
}
171
177
} else {
172
- unsigned n=getNumberOfArguments ()-1 ;
178
+ unsigned n=getNumberOfArguments ()-1 ; double matval = 0 ; unsigned vcol = ind2;
173
179
for (unsigned i=0 ; i<getNumberOfArguments ()-1 ; ++i) {
174
180
unsigned ostrn = getConstPntrToComponent (i)->getPositionInStream ();
175
- double matval = getElementOfMatrixArgument ( i, index1, ind2, myvals ), vecval=getArgumentElement ( n, ind2, myvals );
181
+ Value* myarg = getPntrToArgument (i);
182
+ if ( !myarg->valueHasBeenSet () ) matval = myvals.get ( myarg->getPositionInStream () );
183
+ else {
184
+ matval = myarg->get ( index1*myarg->getNumberOfColumns () + ind2, false );
185
+ vcol = getPntrToArgument (i)->getRowIndex ( index1, ind2 );
186
+ }
187
+ double vecval=getArgumentElement ( n, vcol, myvals );
176
188
// And add this part of the product
177
189
myvals.addValue ( ostrn, matval*vecval );
178
190
// Now lets work out the derivatives
179
191
if ( doNotCalculateDerivatives () ) continue ;
180
- addDerivativeOnMatrixArgument ( stored_arg[i], i, i, index1, ind2, vecval, myvals ); addDerivativeOnVectorArgument ( stored_arg[n], i, n, ind2 , matval, myvals );
192
+ addDerivativeOnMatrixArgument ( stored_arg[i], i, i, index1, ind2, vecval, myvals ); addDerivativeOnVectorArgument ( stored_arg[n], i, n, vcol , matval, myvals );
181
193
}
182
194
}
183
195
}
0 commit comments