#include<cstdio>
#include<cctype>
// Modulus applied to the printed answer.
const long long maxn = 998244353;
using namespace std;

// Values read by main() and the answer it prints (zero-initialized globals).
long long n, m;
long long ans;

// Integer reader for stdin; its definition follows main().
inline long long get();
int main()
{
	// Judge-style file I/O. If either file cannot be opened, exit with a
	// non-zero status instead of reading from / writing to a closed stream.
	if(!freopen("bpmp.in","r",stdin))return 1;
	if(!freopen("bpmp.out","w",stdout))return 1;
	n=get();m=get();
	// Answer: (m-1) + m*(n-1), i.e. n*m - 1, reduced modulo maxn.
	// Reduce (and make non-negative) each factor before multiplying so the
	// product stays below maxn*maxn < 2^63 instead of overflowing long long
	// when n and m are large.
	const long long nr=((n-1)%maxn+maxn)%maxn;
	const long long mr=(m%maxn+maxn)%maxn;
	// mr*nr%maxn + mr - 1 + maxn lies in [maxn-1, 3*maxn-3]: no overflow,
	// and the final % yields a value in [0, maxn).
	ans=(mr*nr%maxn+mr-1+maxn)%maxn;
	printf("%lld",ans);
	return 0;
}
// Reads the next integer from stdin.
// Non-digit characters before the number are skipped; if any of them is '-',
// the result is negated. Returns 0 if end of input is reached before a digit
// is found (previously this looped forever).
// The character is kept in an int so EOF is not confused with byte 0xFF and
// isdigit() never receives a negative non-EOF value (which is UB).
// NOTE(review): assumes the number fits in long long; larger values overflow.
inline long long get()
{
	long long t=0,jud=1;
	int c=getchar();
	while(c!=EOF&&!isdigit(c))
	{
		if(c=='-')jud=-1;
		c=getchar();
	}
	while(c!=EOF&&isdigit(c))
	{
		t=t*10+(c-'0');
		c=getchar();
	}
	return t*jud;
}